import streamlit as st from transformers import pipeline import torchaudio import numpy as np import matplotlib.pyplot as plt from config import MODEL_ID # Load the model and pipeline using the model_id variable pipe = pipeline("audio-classification", model=MODEL_ID) def classify_audio(filepath): preds = pipe(filepath) outputs = {"normal": 0.0, "artifact": 0.0, "murmur": 0.0} for p in preds: label = p["label"] # Simplify the labels and accumulate the scores if "artifact" in label: outputs["artifact"] += p["score"] elif "murmur" in label: outputs["murmur"] += p["score"] elif "extra" in label or "Normal" in label: outputs["normal"] += p["score"] return outputs # Streamlit app layout st.title("Heartbeat Sound Classification") # Theme selection theme = st.sidebar.selectbox( "Select Theme", ["Light Green", "Light Blue"] ) # Add custom CSS for styling based on the selected theme if theme == "Light Green": st.markdown( """ """, unsafe_allow_html=True ) elif theme == "Light Blue": st.markdown( """ """, unsafe_allow_html=True ) # File uploader for audio files uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"]) if uploaded_file is not None: st.subheader("Uploaded Audio File") # Load and display the audio file audio_bytes = uploaded_file.read() st.audio(audio_bytes, format='audio/wav') # Save the uploaded file to a temporary location with open("temp_audio_file.wav", "wb") as f: f.write(audio_bytes) # Load audio for visualization waveform, sample_rate = torchaudio.load("temp_audio_file.wav") # Visualization selection viz_type = st.radio("Select visualization type:", ["Waveform", "Spectrogram"]) # Create visualization fig, ax = plt.subplots(figsize=(10, 4)) if viz_type == "Waveform": time = np.arange(waveform.shape[1]) / sample_rate ax.plot(time, waveform[0].numpy()) ax.set_title("Audio Waveform") ax.set_xlabel("Time (s)") ax.set_ylabel("Amplitude") ax.set_xlim([0, time[-1]]) else: ax.specgram(waveform[0].numpy(), Fs=sample_rate, cmap='viridis', NFFT=1024, noverlap=512) ax.set_title("Spectrogram") ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") st.pyplot(fig) # Classify the audio file st.write("Classifying the audio...") results = classify_audio("temp_audio_file.wav") # Display the classification results st.subheader("Classification Results") results_box = st.empty() results_str = "\n".join([f"{label}: {score:.2f}" for label, score in results.items()]) results_box.text(results_str) # Sample Audio Files for classification st.write("Sample Audio Files:") examples = ['normal.wav', 'murmur.wav', 'extra_systole.wav', 'extra_hystole.wav', 'artifact.wav'] for example in examples: if st.button(example): st.subheader(f"Sample Audio: {example}") audio_bytes = open(example, 'rb').read() st.audio(audio_bytes, format='audio/wav') # Load audio for visualization waveform, sample_rate = torchaudio.load(example) # Visualization selection viz_type = st.radio("Select visualization type:", ["Waveform", "Spectrogram"], key=example) # Create visualization fig, ax = plt.subplots(figsize=(10, 4)) if viz_type == "Waveform": time = np.arange(waveform.shape[1]) / sample_rate ax.plot(time, waveform[0].numpy()) ax.set_title("Audio Waveform") ax.set_xlabel("Time (s)") ax.set_ylabel("Amplitude") ax.set_xlim([0, time[-1]]) else: ax.specgram(waveform[0].numpy(), Fs=sample_rate, cmap='viridis', NFFT=1024, noverlap=512) ax.set_title("Spectrogram") ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") st.pyplot(fig) # Classification results results = classify_audio(example) st.write("Results:") results_str = "\n".join([f"{label}: {score:.2f}" for label, score in results.items()]) st.text(results_str)