lukiod committed on
Commit
0a632d3
·
verified ·
1 Parent(s): 5e1c4da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -47
app.py CHANGED
@@ -64,38 +64,45 @@ def get_last_conv_layer_name(model):
64
  return last_conv_layer
65
 
66
  # Function to generate Grad-CAM heatmap for a given beat and class index
67
- def make_gradcam_heatmap(beat, model, conv_layer_name, class_index):
68
- # Create a model that maps the input beat to the activations of the conv layer and the output predictions
 
 
 
 
 
 
69
  grad_model = tf.keras.models.Model(
70
- [model.inputs],
71
- [model.get_layer(conv_layer_name).output, model.output]
72
  )
73
- # Record operations for automatic differentiation
74
  with tf.GradientTape() as tape:
75
- # Expand dims to add batch axis: shape (1, 257, 1)
76
- beat_tensor = tf.expand_dims(beat, axis=0)
77
- conv_outputs, predictions = grad_model(beat_tensor)
78
- loss = predictions[:, class_index]
79
- # Compute gradients of the target class wrt feature map
 
80
  grads = tape.gradient(loss, conv_outputs)
81
- # Global average pooling over the time dimension to get weights
82
- weights = tf.reduce_mean(grads, axis=1)
83
- # Compute the weighted sum of feature maps along the channel dimension
84
- cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)
85
- cam = tf.squeeze(cam) # Remove batch dimension
86
- # Apply ReLU to the heatmap to keep only positive influences
87
- heatmap = tf.maximum(cam, 0)
88
- # Normalize heatmap to the [0, 1] range
89
- heatmap_max = tf.reduce_max(heatmap)
90
- if heatmap_max == 0:
91
- heatmap = tf.zeros_like(heatmap)
92
- else:
93
- heatmap /= heatmap_max
94
- heatmap = heatmap.numpy()
95
- # Resize heatmap to match the input beat size (if needed)
96
- # For 1D, we use cv2.resize with the new shape (length, 1) then flatten
97
- heatmap = cv2.resize(heatmap, (beat.shape[0], 1)).flatten()
98
- return heatmap
99
 
100
  # Streamlit App Layout
101
  st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
@@ -227,36 +234,45 @@ if record_loaded and record is not None and annotation is not None:
227
  left_col, right_col = st.columns(2)
228
 
229
  def display_class_beats(col, class_name, beat_indices, num_beats):
230
- """Helper function to display beats in a column"""
231
  with col:
232
  st.subheader(class_name)
233
  if len(beat_indices) == 0:
234
  st.warning(f"No {class_name} beats found")
235
  return
236
-
237
- num_to_show = min(num_beats, len(beat_indices))
238
- for i, beat_idx in enumerate(beat_indices[:num_to_show]):
239
- beat = beats[beat_idx]
240
- pred_class = predicted_classes[beat_idx]
241
 
242
- # Generate Grad-CAM heatmap
243
- heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)
244
 
245
- # Create visualization
246
  fig, ax = plt.subplots(figsize=(8, 2))
247
- ax.plot(beat.flatten(), color="black")
248
- sc = ax.scatter(
249
- np.arange(len(beat)),
250
- beat.flatten(),
251
- c=heatmap,
252
- cmap="jet",
253
- s=15
 
 
254
  )
 
 
 
 
 
 
 
 
255
  ax.set_title(f"Beat {beat_idx}")
256
- plt.colorbar(sc, ax=ax)
 
 
257
  st.pyplot(fig)
258
-
259
- # Display left class beats
260
  display_class_beats(left_col, left_class, left_indices, num_beats)
261
 
262
  # Display right class beats
 
64
  return last_conv_layer
65
 
66
  # Function to generate Grad-CAM heatmap for a given beat and class index
67
def generate_grad_cam(model, sample, layer_name):
    """
    Compute a 1D Grad-CAM heatmap for a single ECG beat.

    model : loaded Keras model
    sample : a 3D tensor of shape (1, window_size, 1) — (batch, time, channel)
    layer_name : name of the Conv1D layer to use for Grad-CAM
    returns : 1D numpy heatmap (length = conv layer's time axis), normalized to [0, 1]
    """
    # Build a model that returns both the conv feature maps and the predictions
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(sample)
        # Explain the top predicted class
        class_idx = tf.argmax(predictions[0])
        loss = predictions[:, class_idx]

    # Gradient of the class score w.r.t. the conv feature maps
    grads = tape.gradient(loss, conv_outputs)

    # Global-average-pool the gradients over batch and time axes to get
    # one importance weight per channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))  # shape = (channels,)

    # Remove the batch dimension -> (time, channels)
    conv_outputs = tf.squeeze(conv_outputs, axis=0)

    # Channel-weighted sum of the feature maps -> (time,)
    cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)

    # NOTE: abs (not a ReLU) — strong *negative* influences are highlighted
    # too. Use tf.maximum(cam, 0) for classic Grad-CAM behavior.
    cam = tf.abs(cam)
    cam = cam / (tf.reduce_max(cam) + 1e-8)  # normalize to [0, 1]; eps avoids 0/0

    return cam.numpy()
104
+
105
+
106
 
107
  # Streamlit App Layout
108
  st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
 
234
  left_col, right_col = st.columns(2)
235
 
236
def display_class_beats(col, class_name, beat_indices, num_beats):
    """Render up to num_beats beats of one class in a Streamlit column.

    Each beat is drawn with its Grad-CAM heatmap as a background image
    (via imshow) and the raw ECG trace overlaid on top.
    """
    with col:
        st.subheader(class_name)
        if len(beat_indices) == 0:
            st.warning(f"No {class_name} beats found")
            return

        for beat_idx in beat_indices[:num_beats]:
            beat = beats[beat_idx].flatten()  # shape (window_size,)
            sample = beat.reshape(1, -1, 1).astype(np.float32)

            # generate the 1D heatmap for this beat
            heatmap = generate_grad_cam(model, sample, conv_layer_name)

            # set up figure
            fig, ax = plt.subplots(figsize=(8, 2))
            y_min, y_max = beat.min(), beat.max()
            # Guard against a flat beat: a zero-height extent breaks
            # imshow/set_ylim, so pad the range symmetrically
            if y_min == y_max:
                y_min, y_max = y_min - 1.0, y_max + 1.0

            # Always draw the heatmap background for all beats
            ax.imshow(
                np.expand_dims(heatmap, axis=0),  # shape (1, heatmap_len)
                aspect='auto',
                cmap='jet',
                alpha=0.5,
                extent=[0, len(beat), y_min, y_max]
            )

            # overlay the ECG trace
            ax.plot(beat, linewidth=2, color='blue')

            # Do NOT set a facecolor here - it will block the heatmap
            ax.axis('off')  # clean look
            ax.set_title(f"Beat {beat_idx}")
            ax.set_xlim(0, len(beat))
            ax.set_ylim(y_min, y_max)

            st.pyplot(fig)
            # Close the figure so repeated Streamlit reruns don't leak memory
            plt.close(fig)
275
+ # Display left class beats
 
276
  display_class_beats(left_col, left_class, left_indices, num_beats)
277
 
278
  # Display right class beats