lukiod commited on
Commit
08468cc
·
verified ·
1 Parent(s): 01e9544

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -69
app.py CHANGED
@@ -2,10 +2,13 @@ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
  import tensorflow as tf
 
 
 
 
5
  import matplotlib.pyplot as plt
6
- import neurokit2 as nk
7
 
8
- # Custom activation functions required by the model
9
  def sin_activation(x):
10
  return tf.math.sin(x)
11
 
@@ -26,8 +29,8 @@ def load_model():
26
 
27
  model = load_model()
28
 
29
- # AAMI class mapping matching training code
30
- CLASS_MAP = {
31
  0: "Normal",
32
  1: "Supraventricular Ectopic (SVEB)",
33
  2: "Ventricular Ectopic (VEB)",
@@ -35,73 +38,82 @@ CLASS_MAP = {
35
  4: "Unknown"
36
  }
37
 
38
- def process_mitbih_file(dat_file):
39
- """Process MIT-BIH .dat file using NeuroKit2"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- # Read raw signal data (assuming format 16, 360Hz, gain=200)
42
- signal = np.frombuffer(dat_file.getbuffer(), dtype=np.int16).astype(np.float32)
43
- signal /= 200.0 # Convert to mV using standard gain
44
-
45
- # NeuroKit2 processing with assumed 360Hz sampling rate
46
- ecg_signals, info = nk.ecg_process(signal, sampling_rate=360)
47
- r_peaks = info["ECG_R_Peaks"]
48
-
49
- # Extract beats with same parameters as training
50
- window_size = 257
51
- beats = []
52
- for r in r_peaks:
53
- start = max(0, r - window_size//2)
54
- end = start + window_size
55
- if end <= len(signal):
56
- beat = signal[start:end]
57
- beats.append(beat)
58
-
59
- return np.array(beats)
60
  except Exception as e:
61
- st.error(f"File processing error: {str(e)}")
62
- return None
63
-
64
- # Streamlit UI
65
- st.title("ECG Arrhythmia Detection")
66
- st.write("Upload MIT-BIH .dat file")
67
 
68
- uploaded_file = st.file_uploader(
69
- "Select .dat file",
70
- type=["dat"],
71
- accept_multiple_files=False
 
72
  )
73
 
74
- if uploaded_file is not None:
75
- if st.button("Analyze"):
76
- with st.spinner("Processing ECG signal..."):
77
- beats = process_mitbih_file(uploaded_file)
78
-
79
- if beats is not None and len(beats) > 0:
80
- # Prepare data for model
81
- beats = beats.reshape((-1, 257, 1)).astype(np.float32)
82
-
83
- # Make predictions
84
- predictions = model.predict(beats)
85
- pred_classes = np.argmax(predictions, axis=1)
86
-
87
- # Show results
88
- st.subheader("Analysis Results")
89
- results = pd.DataFrame({
90
- "Beat Index": range(len(beats)),
91
- "Predicted Class": [CLASS_MAP[c] for c in pred_classes],
92
- "Confidence": [f"{np.max(p):.1%}" for p in predictions]
93
- })
94
- st.dataframe(results)
95
-
96
- # Visualizations
97
- st.subheader("ECG Signal")
98
- fig, ax = plt.subplots(1, 2, figsize=(15, 4))
99
- ax[0].plot(beats[0].flatten())
100
- ax[0].set_title("Sample ECG Beat")
101
- class_dist = results["Predicted Class"].value_counts()
102
- ax[1].bar(class_dist.index, class_dist.values)
103
- ax[1].set_title("Class Distribution")
104
- ax[1].tick_params(axis='x', rotation=45)
105
- st.pyplot(fig)
106
- else:
107
- st.error("Failed to extract valid beats from the signal")
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import tensorflow as tf
5
+ import wfdb
6
+ import tempfile
7
+ import os
8
+ from scipy.signal import resample
9
  import matplotlib.pyplot as plt
 
10
 
11
+ # Custom activation functions
12
  def sin_activation(x):
13
  return tf.math.sin(x)
14
 
 
29
 
30
  model = load_model()
31
 
32
+ # AAMI class map
33
+ class_map = {
34
  0: "Normal",
35
  1: "Supraventricular Ectopic (SVEB)",
36
  2: "Ventricular Ectopic (VEB)",
 
38
  4: "Unknown"
39
  }
40
 
41
+ def extract_beats(record, annotation, window_size=257):
42
+ beats = []
43
+ r_locs = annotation.sample
44
+ signal = record.p_signal[:, 0] # Using first channel
45
+
46
+ for r in r_locs:
47
+ start = max(0, r - window_size//2)
48
+ end = min(len(signal), r + window_size//2 + 1)
49
+
50
+ if end - start == window_size:
51
+ beat = signal[start:end]
52
+ beats.append(beat)
53
+
54
+ return np.array(beats)
55
+
56
+ st.title("ECG Arrhythmia Classification")
57
+ st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
58
+
59
+ record_loaded = False
60
+ record = None
61
+ annotation = None
62
+
63
+ # Load Record 108 Button
64
+ if st.button("Load Record 108"):
65
  try:
66
+ base_name = "108"
67
+ record = wfdb.rdrecord(base_name)
68
+ annotation = wfdb.rdann(base_name, 'atr')
69
+ record_loaded = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  except Exception as e:
71
+ st.error(f"Error loading Record 108: {str(e)}")
 
 
 
 
 
72
 
73
+ # File uploader
74
+ uploaded_files = st.file_uploader(
75
+ "Or upload your own files",
76
+ type=["dat", "hea", "atr"],
77
+ accept_multiple_files=True
78
  )
79
 
80
+ if uploaded_files and not record_loaded:
81
+ with tempfile.TemporaryDirectory() as tmpdir:
82
+ for f in uploaded_files:
83
+ file_path = os.path.join(tmpdir, f.name)
84
+ with open(file_path, "wb") as f_out:
85
+ f_out.write(f.getbuffer())
86
+
87
+ base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
88
+ common_base = list(base_names)[0]
89
+
90
+ try:
91
+ record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
92
+ annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
93
+ record_loaded = True
94
+ except Exception as e:
95
+ st.error(f"Error reading uploaded files: {str(e)}")
96
+
97
+ # Run processing if record is loaded
98
+ if record_loaded and record is not None and annotation is not None:
99
+ beats = extract_beats(record, annotation)
100
+ if len(beats) == 0:
101
+ st.error("No valid beats found in the record")
102
+ st.stop()
103
+
104
+ beats = beats.reshape((-1, 257, 1)).astype(np.float32)
105
+ predictions = model.predict(beats)
106
+ predicted_classes = np.argmax(predictions, axis=1)
107
+
108
+ st.subheader("Classification Results")
109
+ results = pd.DataFrame({
110
+ "Beat Index": range(len(beats)),
111
+ "Predicted Class": [class_map[c] for c in predicted_classes],
112
+ "Confidence": np.max(predictions, axis=1)
113
+ })
114
+ st.dataframe(results)
115
+
116
+ st.subheader("Sample ECG Beat")
117
+ fig, ax = plt.subplots()
118
+ ax.plot(beats[0].flatten())
119
+ st.pyplot(fig)