Update app.py
Browse files
app.py
CHANGED
|
@@ -16,18 +16,18 @@ from collections import Counter
|
|
| 16 |
device = torch.device("cpu")
|
| 17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
| 19 |
-
|
| 20 |
-
# model_path = '
|
| 21 |
-
model_path =
|
| 22 |
|
| 23 |
-
if os.path.exists(model_path):
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
|
| 28 |
-
title = "Upload an mp3 file for
|
| 29 |
description = """
|
| 30 |
-
The model was trained on Thai audio recordings with the following sentences
|
| 31 |
ชาวไร่ตัดต้นสนทำท่อนซุง\n
|
| 32 |
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
|
| 33 |
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
|
|
@@ -39,7 +39,13 @@ The model was trained on Thai audio recordings with the following sentences, so
|
|
| 39 |
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
|
| 40 |
"""
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
model.eval()
|
| 44 |
with torch.no_grad():
|
| 45 |
wav_data, _ = sf.read(file_path.name)
|
|
@@ -56,44 +62,15 @@ def actualpredict(file_path):
|
|
| 56 |
logits = model(**inputs).logits
|
| 57 |
logits = logits.squeeze()
|
| 58 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
| 59 |
-
return predicted_class_id
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def predict(file_upload):
|
| 63 |
-
|
| 64 |
-
max_length = 100000
|
| 65 |
-
warn_output = " "
|
| 66 |
-
ans = " "
|
| 67 |
-
# file_path = file_upload
|
| 68 |
-
# if (microphone is not None) and (file_upload is not None):
|
| 69 |
-
# warn_output = (
|
| 70 |
-
# "WARNING: You've uploaded an audio file and used the microphone. "
|
| 71 |
-
# "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
|
| 72 |
-
# )
|
| 73 |
-
|
| 74 |
-
# elif (microphone is None) and (file_upload is None):
|
| 75 |
-
# return "ERROR: You have to either use the microphone or upload an audio file"
|
| 76 |
-
# if(microphone is not None):
|
| 77 |
-
# file_path = microphone
|
| 78 |
-
# if(file_upload is not None):
|
| 79 |
-
# file_path = file_upload
|
| 80 |
|
| 81 |
-
predicted_class_id = actualpredict(file_upload)
|
| 82 |
-
if(predicted_class_id==0):
|
| 83 |
-
ans = "no_parkinson"
|
| 84 |
-
else:
|
| 85 |
-
ans = "parkinson"
|
| 86 |
return predicted_class_id
|
| 87 |
gr.Interface(
|
| 88 |
fn=predict,
|
| 89 |
-
inputs=
|
| 90 |
-
gr.inputs.Audio(source="upload", type="filepath", optional=True),
|
| 91 |
-
],
|
| 92 |
outputs="text",
|
| 93 |
title=title,
|
| 94 |
description=description,
|
| 95 |
).launch()
|
| 96 |
|
| 97 |
-
# gr.inputs.Audio(source="microphone", type="filepath", optional=True),
|
| 98 |
# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
| 99 |
# iface.launch()
|
|
|
|
| 16 |
device = torch.device("cpu")
|
| 17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
| 19 |
+
model_path = "dysarthria_classifier12.pth"
|
| 20 |
+
# model_path = '/home/user/app/dysarthria_classifier12.pth'
|
| 21 |
+
# Restore the fine-tuned classifier weights from the checkpoint at model_path.
# map_location reuses the module-level `device` so the target device is
# configured in exactly one place instead of being hard-coded again here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 22 |
|
| 23 |
+
# if os.path.exists(model_path):
|
| 24 |
+
# print(f"Loading saved model {model_path}")
|
| 25 |
+
# model.load_state_dict(torch.load(model_path))
|
| 26 |
|
| 27 |
|
| 28 |
+
title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
|
| 29 |
description = """
|
| 30 |
+
The model was trained on Thai audio recordings with the following sentences: \n
|
| 31 |
ชาวไร่ตัดต้นสนทำท่อนซุง\n
|
| 32 |
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
|
| 33 |
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
|
|
|
|
| 39 |
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
|
| 40 |
"""
|
| 41 |
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def predict(file_path):
|
| 47 |
+
max_length = 100000
|
| 48 |
+
|
| 49 |
model.eval()
|
| 50 |
with torch.no_grad():
|
| 51 |
wav_data, _ = sf.read(file_path.name)
|
|
|
|
| 62 |
logits = model(**inputs).logits
|
| 63 |
logits = logits.squeeze()
|
| 64 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return predicted_class_id
|
| 67 |
gr.Interface(
|
| 68 |
fn=predict,
|
| 69 |
+
inputs="file",
|
|
|
|
|
|
|
| 70 |
outputs="text",
|
| 71 |
title=title,
|
| 72 |
description=description,
|
| 73 |
).launch()
|
| 74 |
|
|
|
|
| 75 |
# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
| 76 |
# iface.launch()
|