lukiod commited on
Commit
fd3c0db
·
verified ·
1 Parent(s): 638a344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -26,7 +26,6 @@ def load_model():
26
  'gelu': tf.keras.activations.gelu
27
  }
28
  )
29
-
30
  model = load_model()
31
 
32
  # AAMI class map
@@ -42,20 +41,17 @@ def extract_beats(record, annotation, window_size=257):
42
  beats = []
43
  r_locs = annotation.sample
44
  signal = record.p_signal[:, 0] # Using first channel
45
-
46
  for r in r_locs:
47
  start = max(0, r - window_size//2)
48
  end = min(len(signal), r + window_size//2 + 1)
49
-
50
  if end - start == window_size:
51
  beat = signal[start:end]
52
  beats.append(beat)
53
-
54
  return np.array(beats)
55
 
56
  st.title("ECG Arrhythmia Classification")
57
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
58
-
59
  record_loaded = False
60
  record = None
61
  annotation = None
@@ -83,10 +79,10 @@ if uploaded_files and not record_loaded:
83
  file_path = os.path.join(tmpdir, f.name)
84
  with open(file_path, "wb") as f_out:
85
  f_out.write(f.getbuffer())
86
-
87
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
88
  common_base = list(base_names)[0]
89
-
90
  try:
91
  record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
92
  annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
@@ -100,7 +96,7 @@ if record_loaded and record is not None and annotation is not None:
100
  if len(beats) == 0:
101
  st.error("No valid beats found in the record")
102
  st.stop()
103
-
104
  beats = beats.reshape((-1, 257, 1)).astype(np.float32)
105
  predictions = model.predict(beats)
106
  predicted_classes = np.argmax(predictions, axis=1)
@@ -113,7 +109,29 @@ if record_loaded and record is not None and annotation is not None:
113
  })
114
  st.dataframe(results)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  st.subheader("Sample ECG Beat")
117
  fig, ax = plt.subplots()
118
  ax.plot(beats[0].flatten())
119
- st.pyplot(fig)
 
26
  'gelu': tf.keras.activations.gelu
27
  }
28
  )
 
29
  model = load_model()
30
 
31
  # AAMI class map
 
41
  beats = []
42
  r_locs = annotation.sample
43
  signal = record.p_signal[:, 0] # Using first channel
44
+
45
  for r in r_locs:
46
  start = max(0, r - window_size//2)
47
  end = min(len(signal), r + window_size//2 + 1)
 
48
  if end - start == window_size:
49
  beat = signal[start:end]
50
  beats.append(beat)
 
51
  return np.array(beats)
52
 
53
  st.title("ECG Arrhythmia Classification")
54
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
 
55
  record_loaded = False
56
  record = None
57
  annotation = None
 
79
  file_path = os.path.join(tmpdir, f.name)
80
  with open(file_path, "wb") as f_out:
81
  f_out.write(f.getbuffer())
82
+
83
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
84
  common_base = list(base_names)[0]
85
+
86
  try:
87
  record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
88
  annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
 
96
  if len(beats) == 0:
97
  st.error("No valid beats found in the record")
98
  st.stop()
99
+
100
  beats = beats.reshape((-1, 257, 1)).astype(np.float32)
101
  predictions = model.predict(beats)
102
  predicted_classes = np.argmax(predictions, axis=1)
 
109
  })
110
  st.dataframe(results)
111
 
112
+ # Class Distribution Section
113
+ st.subheader("Class Distribution")
114
+
115
+ # Get counts for all classes
116
+ class_indices = list(class_map.keys())
117
+ class_names = [class_map[i] for i in class_indices]
118
+ counts = [np.sum(predicted_classes == i) for i in class_indices]
119
+
120
+ # Create distribution dataframe
121
+ distribution_df = pd.DataFrame({
122
+ "Class": class_names,
123
+ "Count": counts
124
+ })
125
+
126
+ # Display in two columns
127
+ col1, col2 = st.columns([1, 2])
128
+ with col1:
129
+ st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
130
+
131
+ with col2:
132
+ st.bar_chart(distribution_df.set_index('Class'))
133
+
134
  st.subheader("Sample ECG Beat")
135
  fig, ax = plt.subplots()
136
  ax.plot(beats[0].flatten())
137
+ st.pyplot(fig)