throgletworld commited on
Commit
4322133
·
verified ·
1 Parent(s): fb9af37

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -90,8 +90,14 @@ def load_models():
90
  checkpoint_path = "wavlm_stutter_classification_best.pth"
91
  if os.path.exists(checkpoint_path):
92
  checkpoint = torch.load(checkpoint_path, map_location=device)
93
- wavlm_model.load_state_dict(checkpoint['model_state_dict'])
94
- print(f"Loaded checkpoint with {checkpoint.get('val_accuracy', 'N/A')} accuracy")
 
 
 
 
 
 
95
  else:
96
  print("WARNING: No checkpoint found, using random weights")
97
 
 
90
  checkpoint_path = "wavlm_stutter_classification_best.pth"
91
  if os.path.exists(checkpoint_path):
92
  checkpoint = torch.load(checkpoint_path, map_location=device)
93
+ # Handle both formats: direct state_dict OR wrapped in 'model_state_dict'
94
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
95
+ wavlm_model.load_state_dict(checkpoint['model_state_dict'])
96
+ print(f"Loaded checkpoint with {checkpoint.get('val_accuracy', 'N/A')} accuracy")
97
+ else:
98
+ # Direct state_dict (how train_waveLM.py saves it)
99
+ wavlm_model.load_state_dict(checkpoint)
100
+ print("Loaded checkpoint (direct state_dict format)")
101
  else:
102
  print("WARNING: No checkpoint found, using random weights")
103