MoHamdyy commited on
Commit
2d3625d
·
1 Parent(s): 6bdb513

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -380,8 +380,18 @@ try:
380
  TTS_MODEL.load_state_dict(state["state_dict"])
381
  else:
382
  TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
 
383
  TTS_MODEL.eval()
384
 
 
 
 
 
 
 
 
 
 
385
  # Try torch.compile for additional speedup (PyTorch 2.0+)
386
  try:
387
  TTS_MODEL = torch.compile(TTS_MODEL, mode="reduce-overhead")
@@ -389,10 +399,7 @@ try:
389
  except Exception as compile_error:
390
  print(f"Torch compile not available: {compile_error}, using standard model.")
391
 
392
- print("TTS model loaded successfully.")
393
- except Exception as e:
394
- print(f"Error loading TTS model: {e}")
395
- TTS_MODEL = None
396
 
397
  # Load STT (Whisper) Model from Hub
398
  try:
 
380
  TTS_MODEL.load_state_dict(state["state_dict"])
381
  else:
382
  TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
383
+
384
  TTS_MODEL.eval()
385
 
386
+ # Set all submodules to eval mode and disable gradients permanently
387
+ for module in TTS_MODEL.modules():
388
+ if hasattr(module, 'training'):
389
+ module.train(False)
390
+
391
+ # Disable gradients for all parameters permanently
392
+ for param in TTS_MODEL.parameters():
393
+ param.requires_grad = False
394
+
395
  # Try torch.compile for additional speedup (PyTorch 2.0+)
396
  try:
397
  TTS_MODEL = torch.compile(TTS_MODEL, mode="reduce-overhead")
 
399
  except Exception as compile_error:
400
  print(f"Torch compile not available: {compile_error}, using standard model.")
401
 
402
+ print("TTS model loaded successfully (optimized for inference).")
 
 
 
403
 
404
  # Load STT (Whisper) Model from Hub
405
  try: