bakhili commited on
Commit
5c6d9a9
Β·
verified Β·
1 Parent(s): 15bacdd

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +200 -88
src/streamlit_app.py CHANGED
@@ -25,14 +25,6 @@ try:
25
  except ImportError:
26
  MPL_AVAILABLE = False
27
 
28
- # Import our Grad-CAM utilities
29
- try:
30
- from gradcam_utils import create_real_attention_heatmap
31
- GRADCAM_AVAILABLE = True
32
- except ImportError:
33
- GRADCAM_AVAILABLE = False
34
- st.warning("Grad-CAM utilities not available - using simulated heatmaps")
35
-
36
  # Page config
37
  st.set_page_config(
38
  page_title="Stroke Classifier",
@@ -65,6 +57,7 @@ st.markdown("""
65
  .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
66
  .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
67
  .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
 
68
  </style>""", unsafe_allow_html=True)
69
 
70
  # Initialize session state
@@ -126,6 +119,126 @@ def load_stroke_model():
126
  except Exception as e:
127
  return None, f"❌ Model loading failed: {str(e)}"
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def predict_stroke(img, model):
130
  """Predict stroke type from image."""
131
  if model is None:
@@ -152,68 +265,72 @@ def predict_stroke(img, model):
152
  return None, f"Prediction error: {str(e)}"
153
 
154
  def create_simulated_heatmap(img, predictions):
155
- """Create a simulated heatmap (fallback when Grad-CAM is not available)."""
156
  try:
157
- # Create a simple heatmap based on prediction confidence
158
  confidence = np.max(predictions)
159
-
160
- # Generate random attention pattern weighted by confidence
161
- np.random.seed(42) # For reproducible results
162
  heatmap = np.random.rand(224, 224) * confidence
163
 
164
- # Add some structure to make it look more realistic
165
  try:
166
  from scipy import ndimage
167
  heatmap = ndimage.gaussian_filter(heatmap, sigma=20)
168
  except ImportError:
169
- # Fallback without scipy - create a simple gradient
170
  center_x, center_y = 112, 112
171
  y, x = np.ogrid[:224, :224]
172
  mask = (x - center_x)**2 + (y - center_y)**2
173
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
174
 
175
- return heatmap
176
  except Exception as e:
177
- st.error(f"Simulated heatmap generation error: {e}")
178
- return None
179
 
180
- def create_overlay_visualization(img, predictions, model, use_real_gradcam=True):
181
- """Create overlay visualization with real or simulated heatmap."""
182
  if not MPL_AVAILABLE:
183
- return None
184
 
185
  try:
186
- # Resize image to 224x224 to match heatmap
187
  img_resized = img.resize((224, 224))
188
  img_array = np.array(img_resized)
189
 
190
- # Try to get real attention heatmap first
191
  heatmap = None
192
- heatmap_type = "Simulated"
193
 
194
- if use_real_gradcam and GRADCAM_AVAILABLE and model is not None:
195
- heatmap = create_real_attention_heatmap(img, model, predictions)
196
- if heatmap is not None:
197
- heatmap_type = "Real AI Attention (Grad-CAM)"
198
 
199
- # Fallback to simulated heatmap
200
  if heatmap is None:
201
- heatmap = create_simulated_heatmap(img, predictions)
202
- heatmap_type = "Simulated Attention"
 
 
 
203
 
204
  if heatmap is None:
205
- return None, "Could not generate heatmap"
206
 
207
- # Create the overlay visualization
208
  fig, ax = plt.subplots(figsize=(10, 8))
209
 
210
  # Show original image
211
  ax.imshow(img_array)
212
 
213
- # Overlay heatmap with transparency
214
  im = ax.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
215
 
216
- ax.set_title(f"Brain Scan with {heatmap_type}", fontsize=14, fontweight='bold', pad=20)
 
 
 
 
 
 
 
 
217
  ax.axis('off')
218
 
219
  # Add colorbar
@@ -221,11 +338,10 @@ def create_overlay_visualization(img, predictions, model, use_real_gradcam=True)
221
  cbar.set_label('Attention Intensity', rotation=270, labelpad=20)
222
 
223
  plt.tight_layout()
224
- return fig, heatmap_type
225
 
226
  except Exception as e:
227
- st.error(f"Overlay generation error: {e}")
228
- return None, f"Error: {e}"
229
 
230
  # Main App
231
  def main():
@@ -240,11 +356,12 @@ def main():
240
 
241
  # System status
242
  st.markdown("### πŸ”§ System Status")
243
- col1, col2, col3, col4 = st.columns(4)
244
 
245
  with col1:
246
  if TF_AVAILABLE:
247
  st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
 
248
  else:
249
  st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
250
 
@@ -255,12 +372,6 @@ def main():
255
  st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
256
 
257
  with col3:
258
- if GRADCAM_AVAILABLE:
259
- st.markdown('<div class="status-box success">βœ… Grad-CAM Ready</div>', unsafe_allow_html=True)
260
- else:
261
- st.markdown('<div class="status-box warning">⚠️ Grad-CAM Unavailable</div>', unsafe_allow_html=True)
262
-
263
- with col4:
264
  if "βœ…" in st.session_state.model_status:
265
  st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
266
  else:
@@ -269,28 +380,29 @@ def main():
269
  # Model status details
270
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
271
 
272
- # Explanation of heatmap types
273
- with st.expander("πŸ” Understanding AI Attention Heatmaps"):
274
- st.markdown("""
275
- ### What do the heatmaps show?
276
-
277
- **🎯 Real AI Attention (Grad-CAM):**
278
- - Shows **actual** regions the AI model focuses on for its decision
279
- - Uses gradient-weighted class activation mapping
280
- - Highlights pixels that most influence the prediction
281
- - **This is the AI's actual reasoning process**
282
-
283
- **🎨 Simulated Attention:**
284
- - Shows **fake** attention patterns (used when Grad-CAM unavailable)
285
- - Based on random patterns scaled by prediction confidence
286
- - **Does NOT represent actual AI reasoning**
287
- - Used as a visual placeholder only
288
-
289
- ### How to interpret:
290
- - **Red/Yellow areas**: High attention (important for decision)
291
- - **Blue/Purple areas**: Low attention (less important)
292
- - **Intensity**: Stronger colors = more important regions
293
- """)
 
294
 
295
  # Manual reload button
296
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
@@ -309,14 +421,14 @@ def main():
309
  st.markdown("---")
310
  st.header("🎨 Visualization Options")
311
 
312
- use_real_gradcam = st.checkbox(
313
- "Use Real AI Attention (Grad-CAM)",
314
- value=GRADCAM_AVAILABLE,
315
- disabled=not GRADCAM_AVAILABLE,
316
- help="Show actual AI reasoning vs simulated patterns"
317
  )
318
 
319
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
 
320
 
321
  st.markdown("---")
322
  st.header("ℹ️ About")
@@ -382,24 +494,23 @@ def main():
382
  image,
383
  predictions,
384
  st.session_state.model,
385
- use_real_gradcam
386
  )
387
 
388
  if result and len(result) == 2:
389
- overlay_fig, heatmap_type = result
390
  if overlay_fig is not None:
391
  st.pyplot(overlay_fig)
392
  plt.close()
393
 
394
- # Show what type of heatmap is being displayed
395
- if "Real AI Attention" in heatmap_type:
396
- st.success(f"βœ… Showing: {heatmap_type}")
397
- st.info("This heatmap shows the **actual** regions the AI focuses on for its decision.")
398
- else:
399
- st.warning(f"⚠️ Showing: {heatmap_type}")
400
- st.info("This is a **simulated** heatmap and does NOT represent actual AI reasoning.")
401
  else:
402
- st.error("Could not generate visualization")
403
  else:
404
  st.error("Could not generate attention visualization")
405
  else:
@@ -410,20 +521,19 @@ def main():
410
  st.markdown("""
411
  ## πŸ‘‹ Welcome to the Stroke Classification System
412
 
413
- This AI system analyzes brain scan images to detect stroke indicators and shows you **exactly where the AI is looking**.
414
 
415
  ### πŸš€ Features:
416
  - **Deep Learning Classification**: Advanced CNN architecture
417
  - **Real AI Attention Maps**: See actual model reasoning with Grad-CAM
418
- - **Three Classes**: Hemorrhagic Stroke, Ischemic Stroke, No Stroke
419
- - **Real-time Analysis**: Fast processing with confidence scores
420
- - **Transparent AI**: Understand how the AI makes decisions
421
 
422
  ### πŸ“‹ How to Use:
423
- 1. **Check system status** above (Grad-CAM should show βœ… for real attention)
424
  2. **Upload a brain scan image** using the sidebar
425
  3. **View classification results** with confidence scores
426
- 4. **Explore REAL attention visualization** to see where the AI actually looks
427
 
428
  **Get started by uploading an image! πŸ‘ˆ**
429
  """)
@@ -434,3 +544,5 @@ def main():
434
 
435
  if __name__ == "__main__":
436
  main()
 
 
 
25
  except ImportError:
26
  MPL_AVAILABLE = False
27
 
 
 
 
 
 
 
 
 
28
  # Page config
29
  st.set_page_config(
30
  page_title="Stroke Classifier",
 
57
  .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
58
  .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
59
  .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
60
+ .debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
61
  </style>""", unsafe_allow_html=True)
62
 
63
  # Initialize session state
 
119
  except Exception as e:
120
  return None, f"❌ Model loading failed: {str(e)}"
121
 
122
+ def debug_model_layers(model):
123
+ """Debug function to show all model layers."""
124
+ if model is None:
125
+ return "No model loaded"
126
+
127
+ layer_info = []
128
+ conv_layers = []
129
+
130
+ for i, layer in enumerate(model.layers):
131
+ layer_type = type(layer).__name__
132
+ layer_info.append(f"{i}: {layer.name} ({layer_type})")
133
+
134
+ if 'conv' in layer.name.lower() or 'Conv' in layer_type:
135
+ conv_layers.append(f"βœ… {layer.name} ({layer_type})")
136
+
137
+ return {
138
+ 'all_layers': layer_info,
139
+ 'conv_layers': conv_layers,
140
+ 'total_layers': len(model.layers)
141
+ }
142
+
143
+ def find_best_conv_layer(model):
144
+ """Find the best convolutional layer for Grad-CAM."""
145
+ if model is None:
146
+ return None, "No model loaded"
147
+
148
+ conv_layers = []
149
+
150
+ # Look for convolutional layers
151
+ for layer in model.layers:
152
+ layer_type = type(layer).__name__
153
+ if any(conv_type in layer_type for conv_type in ['Conv2D', 'Conv', 'SeparableConv2D']):
154
+ conv_layers.append(layer.name)
155
+
156
+ if not conv_layers:
157
+ return None, "No convolutional layers found"
158
+
159
+ # Return the last convolutional layer
160
+ return conv_layers[-1], f"Found {len(conv_layers)} conv layers, using: {conv_layers[-1]}"
161
+
162
+ def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
163
+ """Generate Grad-CAM heatmap."""
164
+ try:
165
+ # Create a model that maps the input image to the activations of the last conv layer
166
+ grad_model = tf.keras.Model(
167
+ inputs=[model.inputs],
168
+ outputs=[model.get_layer(last_conv_layer_name).output, model.output]
169
+ )
170
+
171
+ # Compute the gradient of the top predicted class
172
+ with tf.GradientTape() as tape:
173
+ last_conv_layer_output, preds = grad_model(img_array)
174
+ if pred_index is None:
175
+ pred_index = tf.argmax(preds[0])
176
+ class_channel = preds[:, pred_index]
177
+
178
+ # Gradient of the output neuron with regard to the output feature map
179
+ grads = tape.gradient(class_channel, last_conv_layer_output)
180
+
181
+ # Mean intensity of the gradient over a specific feature map channel
182
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
183
+
184
+ # Multiply each channel by "how important this channel is"
185
+ last_conv_layer_output = last_conv_layer_output[0]
186
+ heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
187
+ heatmap = tf.squeeze(heatmap)
188
+
189
+ # Normalize the heatmap between 0 & 1
190
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
191
+
192
+ return heatmap.numpy(), None
193
+
194
+ except Exception as e:
195
+ return None, f"Grad-CAM error: {str(e)}"
196
+
197
+ def create_real_gradcam_heatmap(img, model, predictions):
198
+ """Create a real Grad-CAM heatmap."""
199
+ try:
200
+ # Preprocess image
201
+ img_resized = img.resize((224, 224))
202
+ img_array = np.array(img_resized, dtype=np.float32)
203
+
204
+ # Handle grayscale
205
+ if len(img_array.shape) == 2:
206
+ img_array = np.stack([img_array] * 3, axis=-1)
207
+
208
+ # Normalize and add batch dimension
209
+ img_array = np.expand_dims(img_array, axis=0) / 255.0
210
+
211
+ # Find the best convolutional layer
212
+ conv_layer_name, layer_status = find_best_conv_layer(model)
213
+
214
+ if conv_layer_name is None:
215
+ return None, f"❌ {layer_status}"
216
+
217
+ # Generate Grad-CAM heatmap
218
+ heatmap, error = make_gradcam_heatmap(
219
+ img_array,
220
+ model,
221
+ conv_layer_name,
222
+ pred_index=np.argmax(predictions)
223
+ )
224
+
225
+ if error:
226
+ return None, f"❌ {error}"
227
+
228
+ if heatmap is not None:
229
+ # Resize heatmap to match input image size
230
+ heatmap_resized = tf.image.resize(
231
+ heatmap[..., tf.newaxis],
232
+ (224, 224)
233
+ ).numpy()[:, :, 0]
234
+
235
+ return heatmap_resized, f"βœ… Grad-CAM successful using layer: {conv_layer_name}"
236
+ else:
237
+ return None, "❌ Failed to generate heatmap"
238
+
239
+ except Exception as e:
240
+ return None, f"❌ Grad-CAM error: {str(e)}"
241
+
242
  def predict_stroke(img, model):
243
  """Predict stroke type from image."""
244
  if model is None:
 
265
  return None, f"Prediction error: {str(e)}"
266
 
267
  def create_simulated_heatmap(img, predictions):
268
+ """Create a simulated heatmap (fallback)."""
269
  try:
 
270
  confidence = np.max(predictions)
271
+ np.random.seed(42)
 
 
272
  heatmap = np.random.rand(224, 224) * confidence
273
 
 
274
  try:
275
  from scipy import ndimage
276
  heatmap = ndimage.gaussian_filter(heatmap, sigma=20)
277
  except ImportError:
 
278
  center_x, center_y = 112, 112
279
  y, x = np.ogrid[:224, :224]
280
  mask = (x - center_x)**2 + (y - center_y)**2
281
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
282
 
283
+ return heatmap, "⚠️ Using simulated heatmap (Grad-CAM failed)"
284
  except Exception as e:
285
+ return None, f"❌ Simulated heatmap error: {str(e)}"
 
286
 
287
+ def create_overlay_visualization(img, predictions, model, force_gradcam=True):
288
+ """Create overlay visualization with debugging."""
289
  if not MPL_AVAILABLE:
290
+ return None, "❌ Matplotlib not available"
291
 
292
  try:
293
+ # Resize image to 224x224
294
  img_resized = img.resize((224, 224))
295
  img_array = np.array(img_resized)
296
 
 
297
  heatmap = None
298
+ status_message = ""
299
 
300
+ # Try Grad-CAM first
301
+ if force_gradcam and model is not None:
302
+ heatmap, gradcam_status = create_real_gradcam_heatmap(img, model, predictions)
303
+ status_message = gradcam_status
304
 
305
+ # Fallback to simulated if Grad-CAM failed
306
  if heatmap is None:
307
+ heatmap, sim_status = create_simulated_heatmap(img, predictions)
308
+ if status_message:
309
+ status_message += f" | {sim_status}"
310
+ else:
311
+ status_message = sim_status
312
 
313
  if heatmap is None:
314
+ return None, "❌ Could not generate any heatmap"
315
 
316
+ # Create visualization
317
  fig, ax = plt.subplots(figsize=(10, 8))
318
 
319
  # Show original image
320
  ax.imshow(img_array)
321
 
322
+ # Overlay heatmap
323
  im = ax.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
324
 
325
+ # Determine title based on success
326
+ if "βœ… Grad-CAM successful" in status_message:
327
+ title = "🎯 Real AI Attention (Grad-CAM)"
328
+ title_color = 'green'
329
+ else:
330
+ title = "🎨 Simulated Attention (Grad-CAM Failed)"
331
+ title_color = 'orange'
332
+
333
+ ax.set_title(title, fontsize=14, fontweight='bold', pad=20, color=title_color)
334
  ax.axis('off')
335
 
336
  # Add colorbar
 
338
  cbar.set_label('Attention Intensity', rotation=270, labelpad=20)
339
 
340
  plt.tight_layout()
341
+ return fig, status_message
342
 
343
  except Exception as e:
344
+ return None, f"❌ Visualization error: {str(e)}"
 
345
 
346
  # Main App
347
  def main():
 
356
 
357
  # System status
358
  st.markdown("### πŸ”§ System Status")
359
+ col1, col2, col3 = st.columns(3)
360
 
361
  with col1:
362
  if TF_AVAILABLE:
363
  st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
364
+ st.write(f"TF Version: {tf.__version__}")
365
  else:
366
  st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
367
 
 
372
  st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
373
 
374
  with col3:
 
 
 
 
 
 
375
  if "βœ…" in st.session_state.model_status:
376
  st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
377
  else:
 
380
  # Model status details
381
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
382
 
383
+ # Debug model architecture
384
+ if st.session_state.model is not None:
385
+ with st.expander("πŸ” Model Architecture Debug"):
386
+ debug_info = debug_model_layers(st.session_state.model)
387
+
388
+ st.write("**πŸ“Š Model Summary:**")
389
+ st.write(f"- Total layers: {debug_info['total_layers']}")
390
+ st.write(f"- Convolutional layers found: {len(debug_info['conv_layers'])}")
391
+
392
+ if debug_info['conv_layers']:
393
+ st.write("**🎯 Available Convolutional Layers:**")
394
+ for conv_layer in debug_info['conv_layers']:
395
+ st.write(f" {conv_layer}")
396
+
397
+ # Test which layer we'll use
398
+ conv_layer_name, layer_status = find_best_conv_layer(st.session_state.model)
399
+ st.markdown(f'<div class="status-box info"><strong>Selected for Grad-CAM:</strong> {layer_status}</div>', unsafe_allow_html=True)
400
+ else:
401
+ st.markdown('<div class="status-box error">❌ No convolutional layers found - Grad-CAM will not work</div>', unsafe_allow_html=True)
402
+
403
+ with st.expander("All Layers (Advanced)"):
404
+ for layer_info in debug_info['all_layers']:
405
+ st.code(layer_info)
406
 
407
  # Manual reload button
408
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
 
421
  st.markdown("---")
422
  st.header("🎨 Visualization Options")
423
 
424
+ force_gradcam = st.checkbox(
425
+ "Force Grad-CAM Attempt",
426
+ value=True,
427
+ help="Always try Grad-CAM first (recommended)"
 
428
  )
429
 
430
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
431
+ show_debug = st.checkbox("Show Debug Info", value=True)
432
 
433
  st.markdown("---")
434
  st.header("ℹ️ About")
 
494
  image,
495
  predictions,
496
  st.session_state.model,
497
+ force_gradcam
498
  )
499
 
500
  if result and len(result) == 2:
501
+ overlay_fig, status_message = result
502
  if overlay_fig is not None:
503
  st.pyplot(overlay_fig)
504
  plt.close()
505
 
506
+ # Show detailed status
507
+ if show_debug:
508
+ if "βœ… Grad-CAM successful" in status_message:
509
+ st.success(f"βœ… {status_message}")
510
+ else:
511
+ st.warning(f"⚠️ {status_message}")
 
512
  else:
513
+ st.error(f"Could not generate visualization: {status_message}")
514
  else:
515
  st.error("Could not generate attention visualization")
516
  else:
 
521
  st.markdown("""
522
  ## πŸ‘‹ Welcome to the Stroke Classification System
523
 
524
+ This AI system analyzes brain scan images and shows you **exactly where the AI is looking**.
525
 
526
  ### πŸš€ Features:
527
  - **Deep Learning Classification**: Advanced CNN architecture
528
  - **Real AI Attention Maps**: See actual model reasoning with Grad-CAM
529
+ - **Debug Information**: Understand why Grad-CAM works or fails
530
+ - **Transparent AI**: Full visibility into the decision process
 
531
 
532
  ### πŸ“‹ How to Use:
533
+ 1. **Check system status** and **model debug info** above
534
  2. **Upload a brain scan image** using the sidebar
535
  3. **View classification results** with confidence scores
536
+ 4. **Explore attention visualization** - it will tell you if it's real or simulated
537
 
538
  **Get started by uploading an image! πŸ‘ˆ**
539
  """)
 
544
 
545
  if __name__ == "__main__":
546
  main()
547
+
548
+ </merged_code>