| import gradio as gr | |
| import torch | |
| import soundfile as sf | |
| import os | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import soundfile as sf | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification | |
| from sklearn.model_selection import train_test_split | |
| import re | |
| from collections import Counter | |
| from sklearn.metrics import classification_report | |
# Pick the compute device first: the model below is moved onto it, so
# `device` must exist before `from_pretrained(...).to(device)` runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two-label sequence classifier built on the wav2vec2-base-960h encoder.
# NOTE(review): the classification head is presumably freshly initialised
# here (the base checkpoint is a CTC model), so predictions only reflect the
# fine-tuned model once the state dict below is loaded.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

model_path = "dysarthria_classifier12.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    # map_location remaps tensors saved on another device (e.g. a GPU used
    # for training) onto the device selected above.
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Make the missing checkpoint visible instead of silently serving
    # predictions from weights that were never fine-tuned.
    print(f"Warning: {model_path} not found; running without fine-tuned weights")

# Feature extractor/tokenizer paired with the same base checkpoint; predict()
# uses it to normalise raw waveforms.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
def predict(file_path, max_length=100000):
    """Classify one uploaded audio file with the wav2vec2 sequence classifier.

    Steps:

    * Read the waveform from disk.
    * Downmix it to mono.
    * Resample it to 16 kHz if needed.
    * Normalise it with the module-level ``processor``.
    * Zero-pad or truncate it to exactly ``max_length`` samples.
    * Score it with the module-level ``model`` on ``device``.

    Args:
        file_path: The value Gradio's ``"file"`` input delivers. This is
            either a path string or a temp-file wrapper exposing the path as
            ``.name``; which one depends on the Gradio version.
        max_length: Fixed number of samples fed to the model. Defaults to the
            previously hard-coded 100000. Presumably this matches the length
            used during training; confirm before changing it.

    Returns:
        int: Index of the highest-scoring class (argmax over the logits).

    Raises:
        ValueError: If no file was supplied.
    """
    if file_path is None:
        raise ValueError("No audio file was provided.")
    # Older Gradio passes a tempfile wrapper (path in .name); newer passes a str.
    path = getattr(file_path, "name", file_path)

    target_sr = 16000  # rate the processor is told the audio has
    wav_data, sample_rate = sf.read(path)

    if wav_data.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel audio; a 2-D
        # array would be treated by the processor as a batch of tiny clips.
        wav_data = wav_data.mean(axis=1)

    if sample_rate != target_sr:
        # Previously the real rate was discarded and the audio was mislabelled
        # as 16 kHz. Resample properly instead. scipy is already installed as a
        # dependency of scikit-learn, which this file imports.
        from math import gcd
        from scipy.signal import resample_poly

        g = gcd(target_sr, sample_rate)
        wav_data = resample_poly(wav_data, target_sr // g, sample_rate // g)

    inputs = processor(wav_data, sampling_rate=target_sr, return_tensors="pt", padding=True)
    input_values = inputs.input_values.squeeze(0)  # drop the batch axis of size 1

    # Force a fixed-length input: zero-pad short clips, truncate long ones.
    shortfall = max_length - input_values.shape[-1]
    if shortfall > 0:
        input_values = torch.cat([input_values, torch.zeros((shortfall,))], dim=-1)
    else:
        input_values = input_values[:max_length]

    model.eval()
    with torch.no_grad():
        logits = model(input_values=input_values.unsqueeze(0).to(device)).logits
    return torch.argmax(logits.squeeze(), dim=-1).item()
# Gradio UI: one uploaded file in, the predicted class index out as text.
iface = gr.Interface(fn=predict, inputs="file", outputs="text")

# Start the blocking web server only when run as a script, so importing this
# module (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()