Spaces:
Sleeping
Sleeping
File size: 5,138 Bytes
b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 b10f2fc 08fbee0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | import streamlit as st
from transformers import pipeline
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from config import MODEL_ID
# Load the audio-classification pipeline for MODEL_ID once per server process.
# Streamlit re-executes this script top to bottom on every widget interaction,
# so an uncached `pipeline(...)` call would rebuild the model on every rerun.
@st.cache_resource
def _load_pipeline(model_id):
    """Build the Hugging Face audio-classification pipeline for *model_id*."""
    return pipeline("audio-classification", model=model_id)


pipe = _load_pipeline(MODEL_ID)
def classify_audio(filepath, classifier=None):
    """Collapse the model's label scores into normal / artifact / murmur buckets.

    Args:
        filepath: Audio input handed unchanged to the classifier (the app passes
            a path to an audio file).
        classifier: Optional callable used instead of the module-level ``pipe``.
            It must return an iterable of ``{"label": str, "score": float}``
            dicts. Defaults to ``pipe``.

    Returns:
        dict with keys ``"normal"``, ``"artifact"`` and ``"murmur"``, each holding
        the summed score of the model labels folded into that bucket. Labels
        matching no bucket are ignored.
    """
    if classifier is None:
        classifier = pipe
    outputs = {"normal": 0.0, "artifact": 0.0, "murmur": 0.0}
    for pred in classifier(filepath):
        # Match case-insensitively so "Normal" and "normal" spellings both count.
        label = pred["label"].lower()
        if "artifact" in label:
            outputs["artifact"] += pred["score"]
        elif "murmur" in label:
            outputs["murmur"] += pred["score"]
        elif "extra" in label or ("normal" in label and "abnormal" not in label):
            # "extra*" labels are grouped with normal. "abnormal" contains
            # "normal" as a substring, so it is explicitly kept out of this bucket.
            outputs["normal"] += pred["score"]
    return outputs
# Streamlit app layout
st.title("Heartbeat Sound Classification")

# Theme selection
theme = st.sidebar.selectbox(
    "Select Theme",
    ["Light Green", "Light Blue"]
)

# Per-theme palette: (page background, text/button color, button hover color).
_THEME_PALETTES = {
    "Light Green": ("#e8f5e9", "#004d40", "#00332c"),
    "Light Blue": ("#e0f7fa", "#006064", "#004d40"),
}


def _theme_css(background, foreground, hover):
    """Return the <style> block that applies one color palette to the app.

    NOTE(review): Streamlit tags the file uploader with the ``stFileUploader``
    class; the former ``.stFileUpload`` selector matched nothing. The ``> div``
    child structure varies between Streamlit versions -- confirm visually.
    """
    return f"""
<style>
body, .stApp {{
    background-color: {background};
}}
.stApp {{
    color: {foreground};
}}
.stButton > button, .stFileUploader > div {{
    background-color: {foreground};
    color: white;
}}
.stButton > button:hover, .stFileUploader > div:hover {{
    background-color: {hover};
}}
</style>
"""


# Add custom CSS for styling based on the selected theme.
# The selectbox only ever returns one of the keys above.
st.markdown(_theme_css(*_THEME_PALETTES[theme]), unsafe_allow_html=True)
def _render_visualization(waveform, sample_rate, radio_key=None):
    """Let the user pick a waveform or spectrogram view and draw it.

    Only the first channel (``waveform[0]``) is plotted.

    Args:
        waveform: Tensor returned by ``torchaudio.load``; its first channel is
            converted with ``.numpy()`` for plotting.
        sample_rate: Sampling rate in Hz, as returned by ``torchaudio.load``.
        radio_key: Optional Streamlit widget key for the view selector, needed
            when several selectors with the same label appear on one page.
    """
    viz_type = st.radio(
        "Select visualization type:", ["Waveform", "Spectrogram"], key=radio_key
    )
    samples = waveform[0].numpy()
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        if viz_type == "Waveform":
            times = np.arange(samples.shape[0]) / sample_rate
            ax.plot(times, samples)
            ax.set_title("Audio Waveform")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Amplitude")
            if times.size:  # an empty recording has no last timestamp
                ax.set_xlim([0, times[-1]])
        else:
            ax.specgram(samples, Fs=sample_rate, cmap='viridis', NFFT=1024, noverlap=512)
            ax.set_title("Spectrogram")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Frequency (Hz)")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until closed, and Streamlit reruns this
        # script on each interaction, so unclosed figures would pile up.
        plt.close(fig)


def _format_results(results):
    """Render a ``{label: score}`` mapping as one ``label: score`` line per entry."""
    return "\n".join(f"{label}: {score:.2f}" for label, score in results.items())


def _show_uploaded_file(uploaded_file):
    """Play, visualize and classify a file returned by ``st.file_uploader``."""
    import os
    import tempfile

    st.subheader("Uploaded Audio File")
    # getvalue() returns the full buffer regardless of the current stream position.
    audio_bytes = uploaded_file.getvalue()
    # Use the browser-reported MIME type so MP3 uploads are not labelled as WAV.
    st.audio(audio_bytes, format=uploaded_file.type or 'audio/wav')

    # A unique temporary file per upload, keeping the original extension, so
    # concurrent sessions cannot overwrite each other's audio.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(tmp_path)
        _render_visualization(waveform, sample_rate)
        with st.spinner("Classifying the audio..."):
            results = classify_audio(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file should not break the page

    st.subheader("Classification Results")
    st.text(_format_results(results))


def _show_sample(example):
    """Play, visualize and classify the bundled sample file at path *example*."""
    st.subheader(f"Sample Audio: {example}")
    try:
        with open(example, 'rb') as f:
            audio_bytes = f.read()
    except FileNotFoundError:
        st.error(f"Sample file not found: {example}")
        return
    st.audio(audio_bytes, format='audio/wav')
    waveform, sample_rate = torchaudio.load(example)
    _render_visualization(waveform, sample_rate, radio_key=example)
    results = classify_audio(example)
    st.write("Results:")
    st.text(_format_results(results))


# File uploader for audio files
uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"])
if uploaded_file is not None:
    _show_uploaded_file(uploaded_file)

# Sample Audio Files for classification
st.write("Sample Audio Files:")
examples = ['normal.wav', 'murmur.wav', 'extra_systole.wav', 'extra_hystole.wav', 'artifact.wav']
for example in examples:
    # st.button is True only on the rerun triggered by the click. Remember the
    # choice so that changing the visualization radio (another rerun) keeps
    # the selected sample on screen.
    if st.button(example):
        st.session_state["selected_example"] = example
    if st.session_state.get("selected_example") == example:
        _show_sample(example)