Update app.py
Browse files
app.py
CHANGED
|
@@ -16,14 +16,13 @@ from collections import Counter
|
|
| 16 |
device = torch.device("cpu")
|
| 17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
| 19 |
-
|
| 20 |
# model_path = 'model_weights2.pth'
|
| 21 |
-
model_path = '/home/user/app/dysarthria_classifier10.pth'
|
| 22 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
|
| 29 |
title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
|
|
@@ -65,13 +64,13 @@ def predict(file_path):
|
|
| 65 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
| 66 |
|
| 67 |
return predicted_class_id
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
|
| 76 |
-
iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
| 77 |
-
iface.launch()
|
|
|
|
| 16 |
device = torch.device("cpu")
|
| 17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
| 19 |
+
model_path = "dysarthria_classifier12.pth"
|
| 20 |
# model_path = 'model_weights2.pth'
|
| 21 |
+
# model_path = '/home/user/app/dysarthria_classifier10.pth'
|
|
|
|
| 22 |
|
| 23 |
+
if os.path.exists(model_path):
|
| 24 |
+
print(f"Loading saved model {model_path}")
|
| 25 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 26 |
|
| 27 |
|
| 28 |
title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
|
|
|
|
| 64 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
| 65 |
|
| 66 |
return predicted_class_id
|
| 67 |
+
gr.Interface(
|
| 68 |
+
fn=predict,
|
| 69 |
+
inputs="file",
|
| 70 |
+
outputs="text",
|
| 71 |
+
title=title,
|
| 72 |
+
description=description,
|
| 73 |
+
).launch()
|
| 74 |
|
| 75 |
+
# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
| 76 |
+
# iface.launch()
|