Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -178,39 +178,90 @@ if record_loaded and record is not None and annotation is not None:
|
|
| 178 |
st.pyplot(fig)
|
| 179 |
|
| 180 |
# ---------------- Grad-CAM Visualization Section ----------------
|
| 181 |
-
st.subheader("Grad-CAM
|
| 182 |
-
st.write("
|
| 183 |
|
| 184 |
# Automatically detect the last convolutional layer name
|
| 185 |
conv_layer_name = get_last_conv_layer_name(model)
|
| 186 |
if conv_layer_name is not None:
|
| 187 |
st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
st.pyplot(fig)
|
| 179 |
|
| 180 |
# ---------------- Grad-CAM Visualization Section ----------------
|
| 181 |
+
st.subheader("Class Comparison with Grad-CAM")
|
| 182 |
+
st.write("Compare model explanations between classes present in this record")
|
| 183 |
|
| 184 |
# Automatically detect the last convolutional layer name
|
| 185 |
conv_layer_name = get_last_conv_layer_name(model)
|
| 186 |
if conv_layer_name is not None:
|
| 187 |
st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
|
| 188 |
+
|
| 189 |
+
# Get classes actually present in the data
|
| 190 |
+
present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
|
| 191 |
+
if not present_classes:
|
| 192 |
+
st.warning("No classes with detected beats to compare")
|
| 193 |
+
st.stop()
|
| 194 |
+
|
| 195 |
+
# Class selection dropdowns
|
| 196 |
+
col1, col2, col3 = st.columns([1, 1, 1])
|
| 197 |
+
with col1:
|
| 198 |
+
left_class = st.selectbox(
|
| 199 |
+
"Left Class:",
|
| 200 |
+
options=present_classes,
|
| 201 |
+
index=0
|
| 202 |
+
)
|
| 203 |
+
with col2:
|
| 204 |
+
# Default to second class if available, else first
|
| 205 |
+
right_index = 1 if len(present_classes) > 1 else 0
|
| 206 |
+
right_class = st.selectbox(
|
| 207 |
+
"Right Class:",
|
| 208 |
+
options=present_classes,
|
| 209 |
+
index=right_index
|
| 210 |
+
)
|
| 211 |
+
with col3:
|
| 212 |
+
num_beats = st.number_input(
|
| 213 |
+
"Beats per class:",
|
| 214 |
+
min_value=1,
|
| 215 |
+
max_value=10,
|
| 216 |
+
value=3
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Get class indices from names
|
| 220 |
+
class_name_to_idx = {v: k for k, v in class_map.items()}
|
| 221 |
+
left_class_idx = class_name_to_idx[left_class]
|
| 222 |
+
right_class_idx = class_name_to_idx[right_class]
|
| 223 |
+
left_indices = np.where(predicted_classes == left_class_idx)[0]
|
| 224 |
+
right_indices = np.where(predicted_classes == right_class_idx)[0]
|
| 225 |
+
|
| 226 |
+
# Create comparison columns
|
| 227 |
+
left_col, right_col = st.columns(2)
|
| 228 |
+
|
| 229 |
+
def display_class_beats(col, class_name, beat_indices, num_beats):
|
| 230 |
+
"""Helper function to display beats in a column"""
|
| 231 |
+
with col:
|
| 232 |
+
st.subheader(class_name)
|
| 233 |
+
if len(beat_indices) == 0:
|
| 234 |
+
st.warning(f"No {class_name} beats found")
|
| 235 |
+
return
|
| 236 |
|
| 237 |
+
num_to_show = min(num_beats, len(beat_indices))
|
| 238 |
+
for i, beat_idx in enumerate(beat_indices[:num_to_show]):
|
| 239 |
+
beat = beats[beat_idx]
|
| 240 |
+
pred_class = predicted_classes[beat_idx]
|
| 241 |
+
|
| 242 |
+
# Generate Grad-CAM heatmap
|
| 243 |
+
heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)
|
| 244 |
+
|
| 245 |
+
# Create visualization
|
| 246 |
+
fig, ax = plt.subplots(figsize=(8, 2))
|
| 247 |
+
ax.plot(beat.flatten(), color="black")
|
| 248 |
+
sc = ax.scatter(
|
| 249 |
+
np.arange(len(beat)),
|
| 250 |
+
beat.flatten(),
|
| 251 |
+
c=heatmap,
|
| 252 |
+
cmap="jet",
|
| 253 |
+
s=15
|
| 254 |
+
)
|
| 255 |
+
ax.set_title(f"Beat {beat_idx}")
|
| 256 |
+
plt.colorbar(sc, ax=ax)
|
| 257 |
+
st.pyplot(fig)
|
| 258 |
+
|
| 259 |
+
# Display left class beats
|
| 260 |
+
display_class_beats(left_col, left_class, left_indices, num_beats)
|
| 261 |
+
|
| 262 |
+
# Display right class beats
|
| 263 |
+
display_class_beats(right_col, right_class, right_indices, num_beats)
|
| 264 |
+
|
| 265 |
+
# Add comparison note if same class selected
|
| 266 |
+
if left_class == right_class:
|
| 267 |
+
st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")
|