Upload app.py
Browse files
app.py
CHANGED
|
@@ -90,8 +90,14 @@ def load_models():
|
|
| 90 |
checkpoint_path = "wavlm_stutter_classification_best.pth"
|
| 91 |
if os.path.exists(checkpoint_path):
|
| 92 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
else:
|
| 96 |
print("WARNING: No checkpoint found, using random weights")
|
| 97 |
|
|
|
|
| 90 |
checkpoint_path = "wavlm_stutter_classification_best.pth"
|
| 91 |
if os.path.exists(checkpoint_path):
|
| 92 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 93 |
+
# Handle both formats: direct state_dict OR wrapped in 'model_state_dict'
|
| 94 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 95 |
+
wavlm_model.load_state_dict(checkpoint['model_state_dict'])
|
| 96 |
+
print(f"Loaded checkpoint with {checkpoint.get('val_accuracy', 'N/A')} accuracy")
|
| 97 |
+
else:
|
| 98 |
+
# Direct state_dict (how train_waveLM.py saves it)
|
| 99 |
+
wavlm_model.load_state_dict(checkpoint)
|
| 100 |
+
print("Loaded checkpoint (direct state_dict format)")
|
| 101 |
else:
|
| 102 |
print("WARNING: No checkpoint found, using random weights")
|
| 103 |
|