bakhili commited on
Commit
15bacdd
Β·
verified Β·
1 Parent(s): f26049e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +100 -125
src/streamlit_app.py CHANGED
@@ -25,6 +25,14 @@ try:
25
  except ImportError:
26
  MPL_AVAILABLE = False
27
 
 
 
 
 
 
 
 
 
28
  # Page config
29
  st.set_page_config(
30
  page_title="Stroke Classifier",
@@ -56,6 +64,7 @@ st.markdown("""
56
  .success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
57
  .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
58
  .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
 
59
  </style>""", unsafe_allow_html=True)
60
 
61
  # Initialize session state
@@ -142,16 +151,9 @@ def predict_stroke(img, model):
142
  except Exception as e:
143
  return None, f"Prediction error: {str(e)}"
144
 
145
- def create_overlay_heatmap(img, predictions):
146
- """Create an overlay heatmap on the original image."""
147
- if not MPL_AVAILABLE:
148
- return None
149
-
150
  try:
151
- # Resize image to 224x224 to match heatmap
152
- img_resized = img.resize((224, 224))
153
- img_array = np.array(img_resized)
154
-
155
  # Create a simple heatmap based on prediction confidence
156
  confidence = np.max(predictions)
157
 
@@ -170,36 +172,13 @@ def create_overlay_heatmap(img, predictions):
170
  mask = (x - center_x)**2 + (y - center_y)**2
171
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
172
 
173
- # Create the overlay visualization
174
- fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
175
-
176
- # Original image
177
- ax1.imshow(img_array)
178
- ax1.set_title("Original Image", fontsize=12, fontweight='bold')
179
- ax1.axis('off')
180
-
181
- # Heatmap only
182
- im2 = ax2.imshow(heatmap, cmap='jet', alpha=0.8)
183
- ax2.set_title("Attention Heatmap", fontsize=12, fontweight='bold')
184
- ax2.axis('off')
185
- plt.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
186
-
187
- # Overlay - Original image with heatmap overlay
188
- ax3.imshow(img_array)
189
- im3 = ax3.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
190
- ax3.set_title("Overlay Visualization", fontsize=12, fontweight='bold')
191
- ax3.axis('off')
192
- plt.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
193
-
194
- plt.tight_layout()
195
- return fig
196
-
197
  except Exception as e:
198
- st.error(f"Heatmap generation error: {e}")
199
  return None
200
 
201
- def create_single_overlay(img, predictions):
202
- """Create a single overlay image combining original and heatmap."""
203
  if not MPL_AVAILABLE:
204
  return None
205
 
@@ -208,26 +187,25 @@ def create_single_overlay(img, predictions):
208
  img_resized = img.resize((224, 224))
209
  img_array = np.array(img_resized)
210
 
211
- # Create a simple heatmap based on prediction confidence
212
- confidence = np.max(predictions)
 
213
 
214
- # Generate random attention pattern weighted by confidence
215
- np.random.seed(42) # For reproducible results
216
- heatmap = np.random.rand(224, 224) * confidence
 
217
 
218
- # Add some structure to make it look more realistic
219
- try:
220
- from scipy import ndimage
221
- heatmap = ndimage.gaussian_filter(heatmap, sigma=20)
222
- except ImportError:
223
- # Fallback without scipy - create a simple gradient
224
- center_x, center_y = 112, 112
225
- y, x = np.ogrid[:224, :224]
226
- mask = (x - center_x)**2 + (y - center_y)**2
227
- heatmap = np.exp(-mask / (2 * (50**2))) * confidence
228
 
229
- # Create the single overlay visualization
230
- fig, ax = plt.subplots(figsize=(8, 8))
 
 
 
231
 
232
  # Show original image
233
  ax.imshow(img_array)
@@ -235,7 +213,7 @@ def create_single_overlay(img, predictions):
235
  # Overlay heatmap with transparency
236
  im = ax.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
237
 
238
- ax.set_title("Brain Scan with AI Attention Overlay", fontsize=14, fontweight='bold', pad=20)
239
  ax.axis('off')
240
 
241
  # Add colorbar
@@ -243,34 +221,17 @@ def create_single_overlay(img, predictions):
243
  cbar.set_label('Attention Intensity', rotation=270, labelpad=20)
244
 
245
  plt.tight_layout()
246
- return fig
247
 
248
  except Exception as e:
249
  st.error(f"Overlay generation error: {e}")
250
- return None
251
 
252
  # Main App
253
  def main():
254
  # Header
255
  st.markdown('<h1 class="main-header">🧠 AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)
256
 
257
- # Debug info
258
- with st.expander("πŸ” Debug Information"):
259
- st.write(f"**Python Version:** {sys.version}")
260
- st.write(f"**Current Directory:** {os.getcwd()}")
261
- st.write(f"**Available Files:**")
262
-
263
- all_files = []
264
- for root, dirs, files in os.walk('.'):
265
- for file in files:
266
- all_files.append(os.path.join(root, file))
267
-
268
- for file in all_files[:20]: # Show first 20 files
269
- st.write(f" - {file}")
270
-
271
- if len(all_files) > 20:
272
- st.write(f" ... and {len(all_files) - 20} more files")
273
-
274
  # Auto-load model on startup
275
  if not st.session_state.model_loaded:
276
  with st.spinner("Loading AI model..."):
@@ -279,12 +240,11 @@ def main():
279
 
280
  # System status
281
  st.markdown("### πŸ”§ System Status")
282
- col1, col2, col3 = st.columns(3)
283
 
284
  with col1:
285
  if TF_AVAILABLE:
286
  st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
287
- st.write(f"TF Version: {tf.__version__}")
288
  else:
289
  st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
290
 
@@ -295,6 +255,12 @@ def main():
295
  st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
296
 
297
  with col3:
 
 
 
 
 
 
298
  if "βœ…" in st.session_state.model_status:
299
  st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
300
  else:
@@ -303,6 +269,29 @@ def main():
303
  # Model status details
304
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  # Manual reload button
307
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
308
  st.session_state.model_loaded = False
@@ -319,11 +308,12 @@ def main():
319
 
320
  st.markdown("---")
321
  st.header("🎨 Visualization Options")
322
- viz_option = st.radio(
323
- "Choose visualization style:",
324
- ["Single Overlay", "Side-by-Side Comparison", "Heatmap Only"],
325
- index=0,
326
- help="Select how you want to view the AI attention"
 
327
  )
328
 
329
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
@@ -339,6 +329,8 @@ def main():
339
  - No Stroke
340
 
341
  **Input:** 224Γ—224 RGB images
 
 
342
  """)
343
 
344
  if uploaded_file is not None:
@@ -384,48 +376,32 @@ def main():
384
  st.subheader("🎯 AI Attention Visualization")
385
 
386
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
387
- if viz_option == "Single Overlay":
388
- # Create single overlay
389
- overlay_fig = create_single_overlay(image, predictions)
 
 
 
 
 
 
 
 
390
  if overlay_fig is not None:
391
  st.pyplot(overlay_fig)
392
  plt.close()
393
- else:
394
- st.error("Could not generate overlay visualization")
395
-
396
- elif viz_option == "Side-by-Side Comparison":
397
- # Create side-by-side comparison
398
- comparison_fig = create_overlay_heatmap(image, predictions)
399
- if comparison_fig is not None:
400
- st.pyplot(comparison_fig)
401
- plt.close()
402
- else:
403
- st.error("Could not generate comparison visualization")
404
-
405
- elif viz_option == "Heatmap Only":
406
- # Show just the heatmap
407
- if MPL_AVAILABLE:
408
- # Generate heatmap
409
- confidence = np.max(predictions)
410
- np.random.seed(42)
411
- heatmap = np.random.rand(224, 224) * confidence
412
-
413
- try:
414
- from scipy import ndimage
415
- heatmap = ndimage.gaussian_filter(heatmap, sigma=20)
416
- except ImportError:
417
- center_x, center_y = 112, 112
418
- y, x = np.ogrid[:224, :224]
419
- mask = (x - center_x)**2 + (y - center_y)**2
420
- heatmap = np.exp(-mask / (2 * (50**2))) * confidence
421
 
422
- fig, ax = plt.subplots(figsize=(6, 6))
423
- im = ax.imshow(heatmap, cmap='jet', alpha=0.8)
424
- ax.set_title("AI Attention Heatmap", fontweight='bold')
425
- ax.axis('off')
426
- plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
427
- st.pyplot(fig)
428
- plt.close()
 
 
 
 
429
  else:
430
  st.info("Upload an image and run classification to see AI attention visualization")
431
 
@@ -434,21 +410,20 @@ def main():
434
  st.markdown("""
435
  ## πŸ‘‹ Welcome to the Stroke Classification System
436
 
437
- This AI system analyzes brain scan images to detect stroke indicators.
438
 
439
  ### πŸš€ Features:
440
  - **Deep Learning Classification**: Advanced CNN architecture
441
- - **Visual Attention Maps**: See where the model focuses
442
  - **Three Classes**: Hemorrhagic Stroke, Ischemic Stroke, No Stroke
443
  - **Real-time Analysis**: Fast processing with confidence scores
444
- - **Multiple Visualizations**: Choose how to view AI attention
445
 
446
  ### πŸ“‹ How to Use:
447
- 1. **Check system status** above (should show green checkmarks)
448
  2. **Upload a brain scan image** using the sidebar
449
- 3. **Choose visualization style** (Single Overlay recommended)
450
- 4. **View classification results** with confidence scores
451
- 5. **Explore attention visualization** to understand the model's focus
452
 
453
  **Get started by uploading an image! πŸ‘ˆ**
454
  """)
 
25
  except ImportError:
26
  MPL_AVAILABLE = False
27
 
28
+ # Import our Grad-CAM utilities
29
+ try:
30
+ from gradcam_utils import create_real_attention_heatmap
31
+ GRADCAM_AVAILABLE = True
32
+ except ImportError:
33
+ GRADCAM_AVAILABLE = False
34
+ st.warning("Grad-CAM utilities not available - using simulated heatmaps")
35
+
36
  # Page config
37
  st.set_page_config(
38
  page_title="Stroke Classifier",
 
64
  .success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
65
  .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
66
  .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
67
+ .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
68
  </style>""", unsafe_allow_html=True)
69
 
70
  # Initialize session state
 
151
  except Exception as e:
152
  return None, f"Prediction error: {str(e)}"
153
 
154
+ def create_simulated_heatmap(img, predictions):
155
+ """Create a simulated heatmap (fallback when Grad-CAM is not available)."""
 
 
 
156
  try:
 
 
 
 
157
  # Create a simple heatmap based on prediction confidence
158
  confidence = np.max(predictions)
159
 
 
172
  mask = (x - center_x)**2 + (y - center_y)**2
173
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
174
 
175
+ return heatmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  except Exception as e:
177
+ st.error(f"Simulated heatmap generation error: {e}")
178
  return None
179
 
180
+ def create_overlay_visualization(img, predictions, model, use_real_gradcam=True):
181
+ """Create overlay visualization with real or simulated heatmap."""
182
  if not MPL_AVAILABLE:
183
  return None
184
 
 
187
  img_resized = img.resize((224, 224))
188
  img_array = np.array(img_resized)
189
 
190
+ # Try to get real attention heatmap first
191
+ heatmap = None
192
+ heatmap_type = "Simulated"
193
 
194
+ if use_real_gradcam and GRADCAM_AVAILABLE and model is not None:
195
+ heatmap = create_real_attention_heatmap(img, model, predictions)
196
+ if heatmap is not None:
197
+ heatmap_type = "Real AI Attention (Grad-CAM)"
198
 
199
+ # Fallback to simulated heatmap
200
+ if heatmap is None:
201
+ heatmap = create_simulated_heatmap(img, predictions)
202
+ heatmap_type = "Simulated Attention"
 
 
 
 
 
 
203
 
204
+ if heatmap is None:
205
+ return None, "Could not generate heatmap"
206
+
207
+ # Create the overlay visualization
208
+ fig, ax = plt.subplots(figsize=(10, 8))
209
 
210
  # Show original image
211
  ax.imshow(img_array)
 
213
  # Overlay heatmap with transparency
214
  im = ax.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
215
 
216
+ ax.set_title(f"Brain Scan with {heatmap_type}", fontsize=14, fontweight='bold', pad=20)
217
  ax.axis('off')
218
 
219
  # Add colorbar
 
221
  cbar.set_label('Attention Intensity', rotation=270, labelpad=20)
222
 
223
  plt.tight_layout()
224
+ return fig, heatmap_type
225
 
226
  except Exception as e:
227
  st.error(f"Overlay generation error: {e}")
228
+ return None, f"Error: {e}"
229
 
230
  # Main App
231
  def main():
232
  # Header
233
  st.markdown('<h1 class="main-header">🧠 AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  # Auto-load model on startup
236
  if not st.session_state.model_loaded:
237
  with st.spinner("Loading AI model..."):
 
240
 
241
  # System status
242
  st.markdown("### πŸ”§ System Status")
243
+ col1, col2, col3, col4 = st.columns(4)
244
 
245
  with col1:
246
  if TF_AVAILABLE:
247
  st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
 
248
  else:
249
  st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
250
 
 
255
  st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
256
 
257
  with col3:
258
+ if GRADCAM_AVAILABLE:
259
+ st.markdown('<div class="status-box success">βœ… Grad-CAM Ready</div>', unsafe_allow_html=True)
260
+ else:
261
+ st.markdown('<div class="status-box warning">⚠️ Grad-CAM Unavailable</div>', unsafe_allow_html=True)
262
+
263
+ with col4:
264
  if "βœ…" in st.session_state.model_status:
265
  st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
266
  else:
 
269
  # Model status details
270
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
271
 
272
+ # Explanation of heatmap types
273
+ with st.expander("πŸ” Understanding AI Attention Heatmaps"):
274
+ st.markdown("""
275
+ ### What do the heatmaps show?
276
+
277
+ **🎯 Real AI Attention (Grad-CAM):**
278
+ - Shows **actual** regions the AI model focuses on for its decision
279
+ - Uses gradient-weighted class activation mapping
280
+ - Highlights pixels that most influence the prediction
281
+ - **This is the AI's actual reasoning process**
282
+
283
+ **🎨 Simulated Attention:**
284
+ - Shows **fake** attention patterns (used when Grad-CAM unavailable)
285
+ - Based on random patterns scaled by prediction confidence
286
+ - **Does NOT represent actual AI reasoning**
287
+ - Used as a visual placeholder only
288
+
289
+ ### How to interpret:
290
+ - **Red/Yellow areas**: High attention (important for decision)
291
+ - **Blue/Purple areas**: Low attention (less important)
292
+ - **Intensity**: Stronger colors = more important regions
293
+ """)
294
+
295
  # Manual reload button
296
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
297
  st.session_state.model_loaded = False
 
308
 
309
  st.markdown("---")
310
  st.header("🎨 Visualization Options")
311
+
312
+ use_real_gradcam = st.checkbox(
313
+ "Use Real AI Attention (Grad-CAM)",
314
+ value=GRADCAM_AVAILABLE,
315
+ disabled=not GRADCAM_AVAILABLE,
316
+ help="Show actual AI reasoning vs simulated patterns"
317
  )
318
 
319
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
 
329
  - No Stroke
330
 
331
  **Input:** 224Γ—224 RGB images
332
+
333
+ **Attention Method:** Grad-CAM
334
  """)
335
 
336
  if uploaded_file is not None:
 
376
  st.subheader("🎯 AI Attention Visualization")
377
 
378
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
379
+ # Create overlay visualization
380
+ with st.spinner("🎨 Generating attention visualization..."):
381
+ result = create_overlay_visualization(
382
+ image,
383
+ predictions,
384
+ st.session_state.model,
385
+ use_real_gradcam
386
+ )
387
+
388
+ if result and len(result) == 2:
389
+ overlay_fig, heatmap_type = result
390
  if overlay_fig is not None:
391
  st.pyplot(overlay_fig)
392
  plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
+ # Show what type of heatmap is being displayed
395
+ if "Real AI Attention" in heatmap_type:
396
+ st.success(f"βœ… Showing: {heatmap_type}")
397
+ st.info("This heatmap shows the **actual** regions the AI focuses on for its decision.")
398
+ else:
399
+ st.warning(f"⚠️ Showing: {heatmap_type}")
400
+ st.info("This is a **simulated** heatmap and does NOT represent actual AI reasoning.")
401
+ else:
402
+ st.error("Could not generate visualization")
403
+ else:
404
+ st.error("Could not generate attention visualization")
405
  else:
406
  st.info("Upload an image and run classification to see AI attention visualization")
407
 
 
410
  st.markdown("""
411
  ## πŸ‘‹ Welcome to the Stroke Classification System
412
 
413
+ This AI system analyzes brain scan images to detect stroke indicators and shows you **exactly where the AI is looking**.
414
 
415
  ### πŸš€ Features:
416
  - **Deep Learning Classification**: Advanced CNN architecture
417
+ - **Real AI Attention Maps**: See actual model reasoning with Grad-CAM
418
  - **Three Classes**: Hemorrhagic Stroke, Ischemic Stroke, No Stroke
419
  - **Real-time Analysis**: Fast processing with confidence scores
420
+ - **Transparent AI**: Understand how the AI makes decisions
421
 
422
  ### πŸ“‹ How to Use:
423
+ 1. **Check system status** above (Grad-CAM should show βœ… for real attention)
424
  2. **Upload a brain scan image** using the sidebar
425
+ 3. **View classification results** with confidence scores
426
+ 4. **Explore REAL attention visualization** to see where the AI actually looks
 
427
 
428
  **Get started by uploading an image! πŸ‘ˆ**
429
  """)