Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import sounddevice as sd | |
| import soundfile as sf | |
| from transformers import pipeline | |
# Hugging Face Hub model id of the stutter classifier checkpoint.
MODEL_ID = "HareemFatima/distilhubert-finetuned-stutterdetection"


@st.cache_resource
def _load_model():
    """Build the transformers audio-classification pipeline for MODEL_ID.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps a single pipeline per server process instead
    of re-loading the model weights on every rerun.
    """
    return pipeline("audio-classification", model=MODEL_ID)


# Load the model pipeline (cached across Streamlit reruns).
model = _load_model()
# Class index -> stutter type, in the order the original mapping used.
# NOTE(review): assumes the checkpoint's class indices follow this order —
# confirm against the model's config.json id2label.
_STUTTER_TYPES = ("nonstutter", "prolongation", "repetition", "blocks")


def map_label_to_stutter_type(label):
    """Map a predicted classifier label to a human-readable stutter type.

    The transformers ``audio-classification`` pipeline reports each label as
    a *string*: either the checkpoint's ``id2label`` name or a generic
    ``"LABEL_<n>"`` placeholder. Comparing only against the integers 0-3
    therefore sent every real prediction to "Unknown". This accepts:

    * a numeric class index (``0`` .. ``3``), compared with ``==`` as before;
    * a numeric string (``"2"``);
    * a generic pipeline label (``"LABEL_2"``, case-insensitive prefix);
    * an already-named label (``"Repetition"``, case-insensitive).

    Args:
        label: The predicted label, typically an ``int`` or ``str``.

    Returns:
        One of ``"nonstutter"``, ``"prolongation"``, ``"repetition"``,
        ``"blocks"``, or ``"Unknown"`` if the label is not recognized.
    """
    if isinstance(label, str):
        text = label.strip()
        named = text.casefold()
        if named in _STUTTER_TYPES:
            return named
        if text[:6].casefold() == "label_":
            text = text[6:]
        try:
            label = int(text)
        except ValueError:
            return "Unknown"
    # Numeric equality, like the original chain (so 2.0 or numpy ints match).
    for index, stutter_type in enumerate(_STUTTER_TYPES):
        if label == index:
            return stutter_type
    return "Unknown"
# Function to classify audio input and return the stutter type
def classify_audio(audio_input, sampling_rate=None):
    """Classify one audio clip with the module-level ``model`` pipeline.

    Args:
        audio_input: Any input the transformers ``audio-classification``
            pipeline accepts: a file path, raw bytes, a dict with ``"raw"``
            (or ``"array"``) and ``"sampling_rate"`` keys, or a waveform
            array. A multi-channel array shaped (frames, channels) — what
            ``soundfile.read`` returns for stereo files — is averaged down
            to mono, because the pipeline only accepts single-channel audio.
        sampling_rate: Sample rate in Hz of a bare waveform array. When
            given, it is forwarded so the pipeline can resample to the
            model's rate; without it the pipeline treats the array as
            already being at that rate. Ignored for non-array inputs.

    Returns:
        The stutter type string produced by ``map_label_to_stutter_type``
        for the top-scoring prediction.
    """
    if getattr(audio_input, "ndim", 1) > 1:
        # (frames, channels) -> (frames,)
        audio_input = audio_input.mean(axis=1)
    if sampling_rate is not None and hasattr(audio_input, "ndim"):
        audio_input = {"raw": audio_input, "sampling_rate": sampling_rate}
    predictions = model(audio_input)
    # Results come back ordered by score, so the first entry is the best guess.
    top_label = predictions[0]["label"]
    return map_label_to_stutter_type(top_label)
# Streamlit app
def main():
    """Render the app: capture or upload a WAV clip, then show its stutter type.

    Streamlit re-runs the whole script on each widget interaction, so this
    simply returns early until an audio clip is available.
    """
    st.title("Stutter Classification App")

    # st.audio only *plays* audio: it has no recording mode and does not
    # accept start_recording/channels (passing them raises TypeError).
    # sounddevice would record from the server's microphone, not the
    # visitor's, and nothing ever wrote "recording.wav". Record in the
    # browser with st.audio_input when this Streamlit version provides it,
    # and always offer a file upload as well.
    # NOTE(review): the module-level `import sounddevice` needs the PortAudio
    # system library, which a headless server may lack — consider dropping it.
    record = getattr(st, "audio_input", None)
    recorded = record("Capture Audio") if record is not None else None
    uploaded = st.file_uploader("Or upload a WAV file", type=["wav"])
    audio_file = recorded if recorded is not None else uploaded
    if audio_file is None:
        return  # nothing captured yet; wait for the next rerun

    with st.spinner("Classifying..."):
        try:
            audio_data, sampling_rate = sf.read(audio_file)
        except RuntimeError as err:  # soundfile's read errors derive from RuntimeError
            st.error(f"Could not read the audio clip: {err}")
            return
        if audio_data.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio;
            # the classifier expects a single channel.
            audio_data = audio_data.mean(axis=1)
        # Pass the real sample rate so the pipeline can resample to the
        # model's expected rate instead of assuming it already matches.
        stutter_type = classify_audio(
            {"raw": audio_data, "sampling_rate": sampling_rate}
        )
    st.write("Predicted Stutter Type:", stutter_type)


if __name__ == "__main__":
    main()