File size: 3,974 Bytes
28dbb9e
 
 
 
08468cc
 
 
 
28dbb9e
48fb238
08468cc
28dbb9e
 
 
 
 
 
 
 
 
 
48fb238
28dbb9e
 
 
 
 
 
 
 
08468cc
 
28dbb9e
 
 
 
 
 
 
08468cc
 
 
 
fd3c0db
08468cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48fb238
08468cc
 
 
 
48fb238
08468cc
28dbb9e
08468cc
 
 
 
 
28dbb9e
 
08468cc
 
 
 
 
 
fd3c0db
08468cc
 
fd3c0db
08468cc
 
 
 
 
 
 
 
 
 
 
 
 
fd3c0db
08468cc
 
 
 
 
 
 
 
 
 
 
 
fd3c0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08468cc
 
 
fd3c0db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import wfdb
import tempfile
import os
from scipy.signal import resample
import matplotlib.pyplot as plt

# Custom activation functions
def sin_activation(x):
    """Element-wise sine, registered as the 'sin' custom activation."""
    return tf.sin(x)  # documented alias of tf.math.sin

def cos_activation(x):
    """Element-wise cosine, registered as the 'cos' custom activation."""
    return tf.cos(x)  # documented alias of tf.math.cos

# Load the trained Keras model once per session (Streamlit caches the result).
@st.cache_resource
def load_model():
    """Load model.keras, wiring up the custom objects it was trained with."""
    custom = {
        'sin': sin_activation,
        'cos': cos_activation,
        'gelu': tf.keras.activations.gelu,
    }
    return tf.keras.models.load_model("model.keras", custom_objects=custom)

model = load_model()

# AAMI heartbeat classes, keyed by the model's integer output label (0-4).
class_map = dict(enumerate([
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
]))

def extract_beats(record, annotation, window_size=257):
    """Cut fixed-length windows centered on each annotated R-peak.

    Uses the first channel of record.p_signal; peaks too close to either
    end of the signal (window would be truncated) are dropped.
    Returns an array of shape (n_kept_beats, window_size).
    """
    signal = record.p_signal[:, 0]
    half = window_size // 2
    total = len(signal)
    # Candidate (start, end) spans, clipped to the signal boundaries.
    spans = ((max(0, peak - half), min(total, peak + half + 1))
             for peak in annotation.sample)
    windows = [signal[lo:hi] for lo, hi in spans if hi - lo == window_size]
    return np.array(windows)

# Page header and input controls. Streamlit reruns this script top-to-bottom
# on every interaction, so the record/annotation state is re-derived each run.
st.title("ECG Arrhythmia Classification")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
# Tracks whether a WFDB record + annotation pair was successfully loaded
# by either input path below.
record_loaded = False
record = None
annotation = None

# Load Record 108 Button — reads "108.dat"/"108.hea"/"108.atr" from the
# working directory (assumes they ship alongside the app).
if st.button("Load Record 108"):
    try:
        base_name = "108"
        record = wfdb.rdrecord(base_name)
        annotation = wfdb.rdann(base_name, 'atr')
        record_loaded = True
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")

# File uploader — users supply their own WFDB file set; all three extensions
# of one record are expected for wfdb to read it.
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True
)

# Read an uploaded record: stage the files in a temp directory so wfdb can
# open them by base name (wfdb expects .dat/.hea/.atr side by side on disk).
# Skipped when the "Load Record 108" button already produced a record.
if uploaded_files and not record_loaded:
    with tempfile.TemporaryDirectory() as tmpdir:
        for f in uploaded_files:
            file_path = os.path.join(tmpdir, f.name)
            with open(file_path, "wb") as f_out:
                f_out.write(f.getbuffer())

        # All files of one record share a base name (e.g. "108.dat"/"108.hea"/
        # "108.atr"). Sort so the choice is deterministic — the original
        # list(set)[0] picked an arbitrary record when files for several
        # records were uploaded together.
        base_names = sorted({os.path.splitext(f.name)[0] for f in uploaded_files})
        if len(base_names) > 1:
            st.warning(f"Multiple records uploaded; using '{base_names[0]}'")
        common_base = base_names[0]

        try:
            record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
            annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
            record_loaded = True
        except Exception as e:
            st.error(f"Error reading uploaded files: {str(e)}")

# Run classification and render results once a record is available.
if record_loaded and record is not None and annotation is not None:
    beats = extract_beats(record, annotation)
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()  # halts the script run here

    # Shape (n_beats, window, 1) for the model's 1-D input. Derive the window
    # length from the extracted beats rather than hard-coding 257, so a
    # different extract_beats window_size stays consistent automatically.
    beats = beats.reshape((-1, beats.shape[1], 1)).astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    st.subheader("Classification Results")
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        # .get guards against a label index outside the known AAMI classes
        # (class_map[c] would raise KeyError); falls back to "Unknown".
        "Predicted Class": [class_map.get(c, "Unknown") for c in predicted_classes],
        "Confidence": np.max(predictions, axis=1)
    })
    st.dataframe(results)

    # Class Distribution Section — per-class counts of the predictions.
    st.subheader("Class Distribution")

    class_indices = list(class_map.keys())
    class_names = [class_map[i] for i in class_indices]
    counts = [int(np.sum(predicted_classes == i)) for i in class_indices]

    distribution_df = pd.DataFrame({
        "Class": class_names,
        "Count": counts
    })

    # Table on the left, bar chart on the right.
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))

    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    # Plot the first extracted beat as a sanity check of the windowing.
    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten())
    st.pyplot(fig)