File size: 5,138 Bytes
b10f2fc
 
 
08fbee0
 
b10f2fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fbee0
b10f2fc
 
08fbee0
b10f2fc
 
08fbee0
 
b10f2fc
 
08fbee0
b10f2fc
 
 
 
 
 
 
 
 
 
08fbee0
b10f2fc
 
08fbee0
b10f2fc
 
08fbee0
 
b10f2fc
 
08fbee0
b10f2fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fbee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b10f2fc
 
 
 
08fbee0
b10f2fc
 
 
 
 
 
 
 
 
 
 
 
 
08fbee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b10f2fc
 
 
08fbee0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import streamlit as st
from transformers import pipeline
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from config import MODEL_ID

# Load the audio-classification pipeline once and cache it across
# Streamlit reruns — the whole script re-executes on every widget
# interaction, and without caching the model would be reloaded each time.
@st.cache_resource
def _load_pipeline():
    """Build the Hugging Face audio-classification pipeline for MODEL_ID."""
    return pipeline("audio-classification", model=MODEL_ID)

pipe = _load_pipeline()

def classify_audio(filepath):
    """Classify the audio file at *filepath* into three score buckets.

    Runs the module-level ``pipe`` on the file and accumulates the
    prediction scores into ``normal``, ``artifact`` and ``murmur``.

    Returns:
        dict: ``{"normal": float, "artifact": float, "murmur": float}``.
    """
    preds = pipe(filepath)
    outputs = {"normal": 0.0, "artifact": 0.0, "murmur": 0.0}
    for p in preds:
        # Match case-insensitively: the original mixed lowercase checks
        # ("artifact"/"murmur") with a capitalized "Normal", so a model
        # label spelled "normal" would be dropped from every bucket.
        label = p["label"].lower()
        if "artifact" in label:
            outputs["artifact"] += p["score"]
        elif "murmur" in label:
            outputs["murmur"] += p["score"]
        elif "extra" in label or "normal" in label:
            # "extra*" labels (e.g. extrastole / extrahls) are folded into
            # the normal bucket, mirroring the original grouping.
            outputs["normal"] += p["score"]
    return outputs

# Streamlit app layout
st.title("Heartbeat Sound Classification")

# Theme selection
theme = st.sidebar.selectbox(
    "Select Theme",
    ["Light Green", "Light Blue"]
)

# Both themes share the same CSS skeleton and differ only in three colors,
# so keep a single template driven by a palette table instead of two
# near-identical <style> blocks. The button background reuses the
# foreground color, exactly as in the original CSS.
_THEME_PALETTES = {
    "Light Green": {"bg": "#e8f5e9", "fg": "#004d40", "hover": "#00332c"},
    "Light Blue": {"bg": "#e0f7fa", "fg": "#006064", "hover": "#004d40"},
}

_palette = _THEME_PALETTES.get(theme)
if _palette is not None:
    st.markdown(
        f"""
        <style>
        body, .stApp {{
            background-color: {_palette["bg"]};
        }}
        .stApp {{
            color: {_palette["fg"]};
        }}
        .stButton > button, .stFileUpload > div {{
            background-color: {_palette["fg"]};
            color: white;
        }}
        .stButton > button:hover, .stFileUpload > div:hover {{
            background-color: {_palette["hover"]};
        }}
        </style>
        """,
        unsafe_allow_html=True
    )

# File uploader for audio files (wav/mp3 only, per the type filter below).
uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"])

if uploaded_file is not None:
    st.subheader("Uploaded Audio File")
    # Load and display the audio file.
    # NOTE(review): format='audio/wav' is passed even when the upload is an
    # mp3 — most browsers sniff the real content, but confirm playback of
    # mp3 uploads works as intended.
    audio_bytes = uploaded_file.read()
    st.audio(audio_bytes, format='audio/wav')

    # Save the uploaded file to a temporary location so torchaudio and the
    # classification pipeline can read it from disk.
    # NOTE(review): a fixed filename in the working directory is shared by
    # all sessions — concurrent users would clobber each other's uploads,
    # and mp3 bytes get a .wav extension; consider tempfile.NamedTemporaryFile.
    with open("temp_audio_file.wav", "wb") as f:
        f.write(audio_bytes)
    
    # Load audio for visualization: waveform is (channels, samples).
    waveform, sample_rate = torchaudio.load("temp_audio_file.wav")
    
    # Visualization selection (re-runs the script when toggled).
    viz_type = st.radio("Select visualization type:", ["Waveform", "Spectrogram"])
    
    # Create visualization from the first channel only.
    fig, ax = plt.subplots(figsize=(10, 4))
    if viz_type == "Waveform":
        # Convert sample indices to seconds for the x axis.
        time = np.arange(waveform.shape[1]) / sample_rate
        ax.plot(time, waveform[0].numpy())
        ax.set_title("Audio Waveform")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Amplitude")
        ax.set_xlim([0, time[-1]])
    else:
        # Spectrogram of channel 0; 1024-sample FFT windows, 50% overlap.
        ax.specgram(waveform[0].numpy(), Fs=sample_rate, cmap='viridis', NFFT=1024, noverlap=512)
        ax.set_title("Spectrogram")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Frequency (Hz)")
    
    st.pyplot(fig)
    
    # Classify the audio file via the cached pipeline (see classify_audio).
    st.write("Classifying the audio...")
    results = classify_audio("temp_audio_file.wav")
    
    # Display the classification results as "label: score" lines.
    st.subheader("Classification Results")
    results_box = st.empty()
    results_str = "\n".join([f"{label}: {score:.2f}" for label, score in results.items()])
    results_box.text(results_str)

# Sample Audio Files for classification — one button per bundled example;
# pressing a button plays, visualizes and classifies that file.
st.write("Sample Audio Files:")
# NOTE(review): 'extra_hystole.wav' looks like a typo for 'extrasystole' —
# confirm the file actually ships under this exact name before renaming.
examples = ['normal.wav', 'murmur.wav', 'extra_systole.wav', 'extra_hystole.wav', 'artifact.wav']
for example in examples:
    if st.button(example):
        st.subheader(f"Sample Audio: {example}")
        # Read the sample inside a context manager so the handle is closed
        # deterministically (the original leaked the open file object).
        with open(example, 'rb') as sample_file:
            audio_bytes = sample_file.read()
        st.audio(audio_bytes, format='audio/wav')

        # Load audio for visualization: waveform is (channels, samples).
        waveform, sample_rate = torchaudio.load(example)

        # Visualization selection; key=example keeps each sample's radio
        # widget distinct across loop iterations.
        viz_type = st.radio("Select visualization type:", ["Waveform", "Spectrogram"], key=example)

        # Create visualization from the first channel only.
        fig, ax = plt.subplots(figsize=(10, 4))
        if viz_type == "Waveform":
            # Convert sample indices to seconds for the x axis.
            time = np.arange(waveform.shape[1]) / sample_rate
            ax.plot(time, waveform[0].numpy())
            ax.set_title("Audio Waveform")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Amplitude")
            ax.set_xlim([0, time[-1]])
        else:
            # Spectrogram of channel 0; 1024-sample FFT windows, 50% overlap.
            ax.specgram(waveform[0].numpy(), Fs=sample_rate, cmap='viridis', NFFT=1024, noverlap=512)
            ax.set_title("Spectrogram")
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Frequency (Hz)")

        st.pyplot(fig)

        # Classification results rendered as "label: score" lines.
        results = classify_audio(example)
        st.write("Results:")
        results_str = "\n".join(f"{label}: {score:.2f}" for label, score in results.items())
        st.text(results_str)