lukiod commited on
Commit
5e1c4da
·
verified ·
1 Parent(s): 8ac6709

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -30
app.py CHANGED
@@ -178,39 +178,90 @@ if record_loaded and record is not None and annotation is not None:
178
  st.pyplot(fig)
179
 
180
  # ---------------- Grad-CAM Visualization Section ----------------
181
- st.subheader("Grad-CAM Heatmap Visualization for Each Beat")
182
- st.write("Below are Grad-CAM heatmaps overlaying each beat. The heatmaps show the regions contributing most to the predicted class.")
183
 
184
  # Automatically detect the last convolutional layer name
185
  conv_layer_name = get_last_conv_layer_name(model)
186
  if conv_layer_name is not None:
187
  st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
188
-
189
- # Optionally, you can limit the number of beats displayed to avoid long processing times.
190
- # For demonstration, here we process all beats, but you might want to show only the first N beats.
191
- show_all = st.checkbox("Show Grad-CAM for all beats", value=False)
192
- if not show_all:
193
- num_beats_to_show = st.number_input("Number of beats to show:", min_value=1, max_value=len(beats), value=5)
194
- else:
195
- num_beats_to_show = len(beats)
196
-
197
- # Loop over each beat and its prediction to generate Grad-CAM heatmap
198
- for idx in range(num_beats_to_show):
199
- beat = beats[idx]
200
- pred_class = predicted_classes[idx]
201
- predicted_label = class_map[pred_class]
202
- # Compute Grad-CAM heatmap for the beat
203
- heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- # Generate visualization figure
206
- fig, ax = plt.subplots(figsize=(10, 3))
207
- # Plot the raw ECG beat
208
- ax.plot(beat.flatten(), color="black", label="ECG Beat")
209
- # Overlay Grad-CAM heatmap by scatter plotting points with a colormap according to heatmap value
210
- sc = ax.scatter(np.arange(len(beat)), beat.flatten(), c=heatmap, cmap="jet", s=25)
211
- ax.set_title(f"Beat {idx} - Predicted: {predicted_label}")
212
- ax.set_xlabel("Time Index")
213
- ax.set_ylabel("Amplitude")
214
- # Add a colorbar to indicate heatmap intensity
215
- fig.colorbar(sc, ax=ax, label="Grad-CAM Intensity")
216
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  st.pyplot(fig)
179
 
180
  # ---------------- Grad-CAM Visualization Section ----------------
181
+ st.subheader("Class Comparison with Grad-CAM")
182
+ st.write("Compare model explanations between classes present in this record")
183
 
184
  # Automatically detect the last convolutional layer name
185
  conv_layer_name = get_last_conv_layer_name(model)
186
  if conv_layer_name is not None:
187
  st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
188
+
189
+ # Get classes actually present in the data
190
+ present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
191
+ if not present_classes:
192
+ st.warning("No classes with detected beats to compare")
193
+ st.stop()
194
+
195
+ # Class selection dropdowns
196
+ col1, col2, col3 = st.columns([1, 1, 1])
197
+ with col1:
198
+ left_class = st.selectbox(
199
+ "Left Class:",
200
+ options=present_classes,
201
+ index=0
202
+ )
203
+ with col2:
204
+ # Default to second class if available, else first
205
+ right_index = 1 if len(present_classes) > 1 else 0
206
+ right_class = st.selectbox(
207
+ "Right Class:",
208
+ options=present_classes,
209
+ index=right_index
210
+ )
211
+ with col3:
212
+ num_beats = st.number_input(
213
+ "Beats per class:",
214
+ min_value=1,
215
+ max_value=10,
216
+ value=3
217
+ )
218
+
219
+ # Get class indices from names
220
+ class_name_to_idx = {v: k for k, v in class_map.items()}
221
+ left_class_idx = class_name_to_idx[left_class]
222
+ right_class_idx = class_name_to_idx[right_class]
223
+ left_indices = np.where(predicted_classes == left_class_idx)[0]
224
+ right_indices = np.where(predicted_classes == right_class_idx)[0]
225
+
226
+ # Create comparison columns
227
+ left_col, right_col = st.columns(2)
228
+
229
+ def display_class_beats(col, class_name, beat_indices, num_beats):
230
+ """Helper function to display beats in a column"""
231
+ with col:
232
+ st.subheader(class_name)
233
+ if len(beat_indices) == 0:
234
+ st.warning(f"No {class_name} beats found")
235
+ return
236
 
237
+ num_to_show = min(num_beats, len(beat_indices))
238
+ for i, beat_idx in enumerate(beat_indices[:num_to_show]):
239
+ beat = beats[beat_idx]
240
+ pred_class = predicted_classes[beat_idx]
241
+
242
+ # Generate Grad-CAM heatmap
243
+ heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)
244
+
245
+ # Create visualization
246
+ fig, ax = plt.subplots(figsize=(8, 2))
247
+ ax.plot(beat.flatten(), color="black")
248
+ sc = ax.scatter(
249
+ np.arange(len(beat)),
250
+ beat.flatten(),
251
+ c=heatmap,
252
+ cmap="jet",
253
+ s=15
254
+ )
255
+ ax.set_title(f"Beat {beat_idx}")
256
+ plt.colorbar(sc, ax=ax)
257
+ st.pyplot(fig)
258
+
259
+ # Display left class beats
260
+ display_class_beats(left_col, left_class, left_indices, num_beats)
261
+
262
+ # Display right class beats
263
+ display_class_beats(right_col, right_class, right_indices, num_beats)
264
+
265
+ # Add comparison note if same class selected
266
+ if left_class == right_class:
267
+ st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")