lukiod commited on
Commit
48fb238
·
verified ·
1 Parent(s): 0813221

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -71
app.py CHANGED
@@ -2,12 +2,10 @@ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
  import tensorflow as tf
5
- import wfdb
6
- import tempfile
7
- import os
8
- from scipy.signal import resample
9
  import matplotlib.pyplot as plt
10
- # Custom activation functions
 
 
11
  def sin_activation(x):
12
  return tf.math.sin(x)
13
 
@@ -18,7 +16,7 @@ def cos_activation(x):
18
  @st.cache_resource
19
  def load_model():
20
  return tf.keras.models.load_model(
21
- "model.keras", # Use the .keras format instead of .h5
22
  custom_objects={
23
  'sin': sin_activation,
24
  'cos': cos_activation,
@@ -28,8 +26,8 @@ def load_model():
28
 
29
  model = load_model()
30
 
31
- # AAMI class map
32
- class_map = {
33
  0: "Normal",
34
  1: "Supraventricular Ectopic (SVEB)",
35
  2: "Ventricular Ectopic (VEB)",
@@ -37,73 +35,73 @@ class_map = {
37
  4: "Unknown"
38
  }
39
 
40
- def extract_beats(record, annotation, window_size=257):
41
- beats = []
42
- r_locs = annotation.sample
43
- signal = record.p_signal[:, 0] # Using first channel
44
-
45
- for r in r_locs:
46
- start = max(0, r - window_size//2)
47
- end = min(len(signal), r + window_size//2 + 1)
48
 
49
- if end - start == window_size:
50
- beat = signal[start:end]
51
- beats.append(beat)
52
-
53
- return np.array(beats)
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- st.title("ECG Arrhythmia Classification")
56
- st.write("Upload MIT-BIH record files (.dat, .hea, .atr)")
 
57
 
58
- uploaded_files = st.file_uploader(
59
- "Choose files",
60
- type=["dat", "hea", "atr"],
61
- accept_multiple_files=True
62
  )
63
 
64
- if uploaded_files:
65
- with tempfile.TemporaryDirectory() as tmpdir:
66
- # Save uploaded files
67
- for f in uploaded_files:
68
- file_path = os.path.join(tmpdir, f.name)
69
- with open(file_path, "wb") as f_out:
70
- f_out.write(f.getbuffer())
71
-
72
- # Find base record name
73
- base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
74
- common_base = list(base_names)[0] # Get first base name
75
-
76
- try:
77
- # Read record
78
- record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
79
- annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
80
 
81
- # Process beats
82
- beats = extract_beats(record, annotation)
83
- if len(beats) == 0:
84
- st.error("No valid beats found in the record")
85
- st.stop()
86
 
87
- # Preprocess and predict
88
- beats = beats.reshape((-1, 257, 1)).astype(np.float32)
89
- predictions = model.predict(beats)
90
- predicted_classes = np.argmax(predictions, axis=1)
91
-
92
- # Display results
93
- st.subheader("Classification Results")
94
- results = pd.DataFrame({
95
- "Beat Index": range(len(beats)),
96
- "Predicted Class": [class_map[c] for c in predicted_classes],
97
- "Confidence": np.max(predictions, axis=1)
98
- })
99
-
100
- st.dataframe(results)
101
-
102
- # Add visualization
103
- st.subheader("Sample ECG Beat")
104
- fig, ax = plt.subplots()
105
- ax.plot(beats[0].flatten())
106
- st.pyplot(fig)
107
-
108
- except Exception as e:
109
- st.error(f"Error processing files: {str(e)}")
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import tensorflow as tf
 
 
 
 
5
  import matplotlib.pyplot as plt
6
+ import neurokit2 as nk
7
+
8
+ # Custom activation functions required by the model
9
  def sin_activation(x):
10
  return tf.math.sin(x)
11
 
 
16
  @st.cache_resource
17
  def load_model():
18
  return tf.keras.models.load_model(
19
+ "model.keras",
20
  custom_objects={
21
  'sin': sin_activation,
22
  'cos': cos_activation,
 
26
 
27
  model = load_model()
28
 
29
+ # AAMI class mapping matching training code
30
+ CLASS_MAP = {
31
  0: "Normal",
32
  1: "Supraventricular Ectopic (SVEB)",
33
  2: "Ventricular Ectopic (VEB)",
 
35
  4: "Unknown"
36
  }
37
 
38
+ def process_mitbih_file(dat_file):
39
+ """Process MIT-BIH .dat file using NeuroKit2"""
40
+ try:
41
+ # Read raw signal data (assuming format 16, 360Hz, gain=200)
42
+ signal = np.frombuffer(dat_file.getbuffer(), dtype=np.int16).astype(np.float32)
43
+ signal /= 200.0 # Convert to mV using standard gain
 
 
44
 
45
+ # NeuroKit2 processing with assumed 360Hz sampling rate
46
+ ecg_signals, info = nk.ecg_process(signal, sampling_rate=360)
47
+ r_peaks = info["ECG_R_Peaks"]
48
+
49
+ # Extract beats with same parameters as training
50
+ window_size = 257
51
+ beats = []
52
+ for r in r_peaks:
53
+ start = max(0, r - window_size//2)
54
+ end = start + window_size
55
+ if end <= len(signal):
56
+ beat = signal[start:end]
57
+ beats.append(beat)
58
+
59
+ return np.array(beats)
60
+ except Exception as e:
61
+ st.error(f"File processing error: {str(e)}")
62
+ return None
63
 
64
+ # Streamlit UI
65
+ st.title("ECG Arrhythmia Detection")
66
+ st.write("Upload MIT-BIH .dat file")
67
 
68
+ uploaded_file = st.file_uploader(
69
+ "Select .dat file",
70
+ type=["dat"],
71
+ accept_multiple_files=False
72
  )
73
 
74
+ if uploaded_file is not None:
75
+ if st.button("Analyze"):
76
+ with st.spinner("Processing ECG signal..."):
77
+ beats = process_mitbih_file(uploaded_file)
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ if beats is not None and len(beats) > 0:
80
+ # Prepare data for model
81
+ beats = beats.reshape((-1, 257, 1)).astype(np.float32)
 
 
82
 
83
+ # Make predictions
84
+ predictions = model.predict(beats)
85
+ pred_classes = np.argmax(predictions, axis=1)
86
+
87
+ # Show results
88
+ st.subheader("Analysis Results")
89
+ results = pd.DataFrame({
90
+ "Beat Index": range(len(beats)),
91
+ "Predicted Class": [CLASS_MAP[c] for c in pred_classes],
92
+ "Confidence": [f"{np.max(p):.1%}" for p in predictions]
93
+ })
94
+ st.dataframe(results)
95
+
96
+ # Visualizations
97
+ st.subheader("ECG Signal")
98
+ fig, ax = plt.subplots(1, 2, figsize=(15, 4))
99
+ ax[0].plot(beats[0].flatten())
100
+ ax[0].set_title("Sample ECG Beat")
101
+ class_dist = results["Predicted Class"].value_counts()
102
+ ax[1].bar(class_dist.index, class_dist.values)
103
+ ax[1].set_title("Class Distribution")
104
+ ax[1].tick_params(axis='x', rotation=45)
105
+ st.pyplot(fig)
106
+ else:
107
+ st.error("Failed to extract valid beats from the signal")