File size: 3,510 Bytes
28dbb9e
 
 
 
 
48fb238
 
 
28dbb9e
 
 
 
 
 
 
 
 
 
48fb238
28dbb9e
 
 
 
 
 
 
 
 
48fb238
 
28dbb9e
 
 
 
 
 
 
48fb238
 
 
 
 
 
28dbb9e
48fb238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28dbb9e
48fb238
 
 
28dbb9e
48fb238
 
 
 
28dbb9e
 
48fb238
 
 
 
28dbb9e
48fb238
 
 
28dbb9e
48fb238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import neurokit2 as nk

# Custom activation functions required by the model
def sin_activation(x):
    """Element-wise sine, supplied to Keras as the custom ``"sin"`` activation."""
    return tf.sin(x)

def cos_activation(x):
    """Element-wise cosine, supplied to Keras as the custom ``"cos"`` activation."""
    return tf.cos(x)

# Load the trained network once; st.cache_resource keeps the returned model
# so it is not re-read from disk on every script rerun.
@st.cache_resource
def load_model():
    """Load ``model.keras`` together with the activations it was saved with.

    The saved model refers to its activations by the names ``sin``, ``cos``
    and ``gelu``; Keras resolves those names through ``custom_objects``.
    """
    activations = {
        "sin": sin_activation,
        "cos": cos_activation,
        "gelu": tf.keras.activations.gelu,
    }
    return tf.keras.models.load_model("model.keras", custom_objects=activations)


model = load_model()

# AAMI beat classes keyed by the model's output index (argmax position).
# The order must match the label encoding used in the training code.
CLASS_MAP = dict(
    enumerate(
        (
            "Normal",
            "Supraventricular Ectopic (SVEB)",
            "Ventricular Ectopic (VEB)",
            "Fusion Beat",
            "Unknown",
        )
    )
)

def _extract_beats(signal, r_peaks, window_size):
    """Return the ``window_size``-sample windows of ``signal`` centred on ``r_peaks``.

    Peaks whose window would run past either end of ``signal`` are skipped, so
    every returned row has its R-peak at index ``window_size // 2``. The result
    has shape ``(n_beats, window_size)`` and is float32, even when empty.
    """
    half = window_size // 2
    starts = np.asarray(r_peaks, dtype=np.intp) - half
    # Drop windows that would be truncated at the start or the end; clamping
    # the start to 0 (as before) silently shifted the R-peak off-centre.
    keep = (starts >= 0) & (starts + window_size <= len(signal))
    starts = starts[keep]
    # Row i holds the sample indices starts[i] .. starts[i] + window_size - 1.
    index = starts[:, np.newaxis] + np.arange(window_size)
    return signal[index].astype(np.float32, copy=False)


def process_mitbih_file(dat_file, sampling_rate=360, gain=200.0, window_size=257):
    """Decode an uploaded MIT-BIH ``.dat`` file and cut fixed-length beat windows.

    The file contents are read as WFDB format-16 samples (little-endian signed
    16-bit integers, treated as a single channel) and divided by ``gain``. R-peaks
    are located with NeuroKit2, and a ``window_size``-sample window centred on
    each peak is extracted.

    NOTE(review): PhysioNet's MIT-BIH Arrhythmia records are commonly stored in
    WFDB format 212 with two interleaved channels; such files will not decode
    correctly as format 16. Confirm the format against the record's ``.hea``
    header before trusting the results.

    Args:
        dat_file: Uploaded file object exposing ``getbuffer()`` (e.g. the object
            returned by ``st.file_uploader``).
        sampling_rate: Sampling frequency in Hz passed to NeuroKit2.
        gain: Divisor converting raw ADC counts to millivolts.
        window_size: Number of samples per extracted beat.

    Returns:
        A float32 array of shape ``(n_beats, window_size)``. ``n_beats`` is 0 when
        no complete, centred window fits. Returns ``None`` after showing a
        Streamlit error if decoding or peak detection raises.
    """
    try:
        buffer = dat_file.getbuffer()
        # Ignore a dangling odd byte rather than letting frombuffer raise.
        n_samples = buffer.nbytes // 2
        signal = np.frombuffer(buffer, dtype="<i2", count=n_samples).astype(np.float32)
        signal /= gain

        # Only the R-peak indices are used; the processed signals are discarded.
        _, info = nk.ecg_process(signal, sampling_rate=sampling_rate)
        return _extract_beats(signal, info["ECG_R_Peaks"], window_size)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        st.error(f"File processing error: {e}")
        return None

# Streamlit UI
st.title("ECG Arrhythmia Detection")
st.write("Upload MIT-BIH .dat file")

uploaded_file = st.file_uploader(
    "Select .dat file",
    type=["dat"],
    accept_multiple_files=False,
)

# The Analyze button is only rendered once a file has been uploaded.
if uploaded_file is not None and st.button("Analyze"):
    with st.spinner("Processing ECG signal..."):
        beats = process_mitbih_file(uploaded_file)

        if beats is not None and len(beats) > 0:
            # Append a channel axis: (n_beats, window) -> (n_beats, window, 1).
            # The window length comes from the data instead of a duplicated 257.
            beats = np.asarray(beats, dtype=np.float32)[..., np.newaxis]

            predictions = model.predict(beats)
            pred_classes = np.argmax(predictions, axis=1)

            st.subheader("Analysis Results")
            results = pd.DataFrame({
                "Beat Index": range(len(beats)),
                # Fall back to a generic label instead of raising KeyError if
                # the model emits an index CLASS_MAP does not cover.
                "Predicted Class": [
                    CLASS_MAP.get(int(c), f"Class {c}") for c in pred_classes
                ],
                # NOTE(review): assumes model outputs are per-class
                # probabilities (e.g. softmax) — confirm against the model.
                "Confidence": [f"{conf:.1%}" for conf in np.max(predictions, axis=1)],
            })
            st.dataframe(results)

            st.subheader("ECG Signal")
            fig, ax = plt.subplots(1, 2, figsize=(15, 4))
            try:
                ax[0].plot(beats[0].flatten())
                ax[0].set_title("Sample ECG Beat")
                class_dist = results["Predicted Class"].value_counts()
                ax[1].bar(class_dist.index, class_dist.values)
                ax[1].set_title("Class Distribution")
                ax[1].tick_params(axis='x', rotation=45)
                st.pyplot(fig)
            finally:
                # Streamlit re-executes this script on every interaction; pyplot
                # keeps each figure alive until it is explicitly closed.
                plt.close(fig)
        else:
            st.error("Failed to extract valid beats from the signal")