bakhili commited on
Commit
9b1ad10
Β·
verified Β·
1 Parent(s): 3130041

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +138 -77
src/streamlit_app.py CHANGED
@@ -267,8 +267,31 @@ def make_gradcam_heatmap(img_array, model, layer_name, pred_index=None):
267
  except Exception as e:
268
  return None, f"Grad-CAM computation error: {str(e)}"
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  def create_real_gradcam_heatmap(img, model, predictions):
271
- """Create a real Grad-CAM heatmap with fallback strategies."""
272
  try:
273
  # Preprocess image
274
  img_resized = img.resize((224, 224))
@@ -285,7 +308,7 @@ def create_real_gradcam_heatmap(img, model, predictions):
285
  layer_name, layer_status = find_best_layer_for_gradcam(model)
286
 
287
  if layer_name is None:
288
- return None, f"❌ {layer_status}"
289
 
290
  # Generate Grad-CAM heatmap
291
  heatmap, error = make_gradcam_heatmap(
@@ -296,7 +319,7 @@ def create_real_gradcam_heatmap(img, model, predictions):
296
  )
297
 
298
  if error:
299
- return None, f"❌ {error} (Layer: {layer_name})"
300
 
301
  if heatmap is not None:
302
  # Resize heatmap to match input image size
@@ -308,12 +331,23 @@ def create_real_gradcam_heatmap(img, model, predictions):
308
  else:
309
  heatmap_resized = heatmap
310
 
311
- return heatmap_resized, f"βœ… Grad-CAM successful using layer: {layer_name}"
 
 
 
 
 
 
 
 
 
 
 
312
  else:
313
- return None, "❌ Failed to generate heatmap"
314
 
315
  except Exception as e:
316
- return None, f"❌ Grad-CAM error: {str(e)}"
317
 
318
  def predict_stroke(img, model):
319
  """Predict stroke type from image."""
@@ -356,12 +390,22 @@ def create_simulated_heatmap(img, predictions):
356
  mask = (x - center_x)**2 + (y - center_y)**2
357
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
358
 
359
- return heatmap, "⚠️ Using simulated heatmap (Grad-CAM failed)"
 
 
 
 
 
 
 
 
 
 
360
  except Exception as e:
361
- return None, f"❌ Simulated heatmap error: {str(e)}"
362
 
363
- def create_overlay_visualization(img, predictions, model, force_gradcam=True):
364
- """Create overlay visualization with debugging."""
365
  if not MPL_AVAILABLE:
366
  return None, "❌ Matplotlib not available"
367
 
@@ -372,52 +416,64 @@ def create_overlay_visualization(img, predictions, model, force_gradcam=True):
372
 
373
  heatmap = None
374
  status_message = ""
 
375
 
376
  # Try Grad-CAM first
377
  if force_gradcam and model is not None:
378
- heatmap, gradcam_status = create_real_gradcam_heatmap(img, model, predictions)
379
- status_message = gradcam_status
 
 
380
 
381
  # Fallback to simulated if Grad-CAM failed
382
  if heatmap is None:
383
- heatmap, sim_status = create_simulated_heatmap(img, predictions)
384
- if status_message:
385
- status_message += f" | {sim_status}"
386
- else:
387
- status_message = sim_status
 
 
388
 
389
  if heatmap is None:
390
  return None, "❌ Could not generate any heatmap"
391
 
392
- # Create visualization
393
- fig, ax = plt.subplots(figsize=(10, 8))
394
 
395
- # Show original image
396
- ax.imshow(img_array)
 
 
397
 
398
- # Overlay heatmap
399
- im = ax.imshow(heatmap, cmap='jet', alpha=0.4, interpolation='bilinear')
 
 
 
 
 
 
 
400
 
401
  # Determine title based on success
402
  if "βœ… Grad-CAM successful" in status_message:
403
- title = "🎯 Real AI Attention (Grad-CAM)"
404
  title_color = 'green'
405
  else:
406
- title = "🎨 Simulated Attention (Grad-CAM Failed)"
407
  title_color = 'orange'
408
 
409
- ax.set_title(title, fontsize=14, fontweight='bold', pad=20, color=title_color)
410
- ax.axis('off')
411
-
412
- # Add colorbar
413
- cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
414
- cbar.set_label('Attention Intensity', rotation=270, labelpad=20)
415
 
416
  plt.tight_layout()
417
- return fig, status_message
 
418
 
419
  except Exception as e:
420
- return None, f"❌ Visualization error: {str(e)}"
421
 
422
  # Main App
423
  def main():
@@ -479,13 +535,6 @@ def main():
479
  for layer in analysis['dense_layers']:
480
  st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
481
 
482
- if analysis['other_layers']:
483
- st.write("**βš™οΈ Other Layers:**")
484
- for layer in analysis['other_layers'][:5]: # Show first 5
485
- st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
486
- if len(analysis['other_layers']) > 5:
487
- st.write(f" ... and {len(analysis['other_layers']) - 5} more")
488
-
489
  # Test layer selection
490
  layer_name, layer_status = find_best_layer_for_gradcam(st.session_state.model)
491
  if "βœ…" in layer_status:
@@ -494,12 +543,6 @@ def main():
494
  st.markdown(f'<div class="status-box warning"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
495
  else:
496
  st.markdown(f'<div class="status-box error"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
497
-
498
- # Show model architecture recommendation
499
- if analysis['model_type'] == 'MLP (Multi-Layer Perceptron)':
500
- st.warning("⚠️ **Note:** Your model appears to be a Multi-Layer Perceptron (MLP) without convolutional layers. Grad-CAM works best with CNNs. The visualization may be limited or use experimental methods.")
501
- elif analysis['model_type'] == 'Custom Architecture':
502
- st.info("ℹ️ **Note:** Your model has a custom architecture. Grad-CAM compatibility depends on the specific layers used.")
503
 
504
  # Manual reload button
505
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
@@ -524,22 +567,27 @@ def main():
524
  help="Try Grad-CAM even with non-CNN models (experimental)"
525
  )
526
 
 
 
 
 
 
 
 
527
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
528
  show_debug = st.checkbox("Show Debug Info", value=True)
 
529
 
530
  st.markdown("---")
531
- st.header("ℹ️ About")
532
  st.info("""
533
- **Model Architecture:** Detected automatically
534
-
535
- **Classes:**
536
- - Hemorrhagic Stroke
537
- - Ischemic Stroke
538
- - No Stroke
539
-
540
- **Input:** 224Γ—224 RGB images
541
-
542
- **Attention Method:** Grad-CAM (when possible)
543
  """)
544
 
545
  if uploaded_file is not None:
@@ -547,7 +595,7 @@ def main():
547
  image = Image.open(uploaded_file)
548
 
549
  # Main content area
550
- col1, col2 = st.columns([1, 1])
551
 
552
  with col1:
553
  st.subheader("πŸ“‹ Classification Results")
@@ -582,20 +630,23 @@ def main():
582
  st.error("❌ Model not loaded. Check the debug information above to see available files.")
583
 
584
  with col2:
585
- st.subheader("🎯 AI Attention Visualization")
586
 
587
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
588
- # Create overlay visualization
589
- with st.spinner("🎨 Generating attention visualization..."):
590
- result = create_overlay_visualization(
591
  image,
592
  predictions,
593
  st.session_state.model,
594
- force_gradcam
 
595
  )
596
 
597
- if result and len(result) == 2:
598
- overlay_fig, status_message = result
 
 
599
  if overlay_fig is not None:
600
  st.pyplot(overlay_fig)
601
  plt.close()
@@ -608,6 +659,17 @@ def main():
608
  st.warning(f"⚠️ {status_message}")
609
  else:
610
  st.info(f"ℹ️ {status_message}")
 
 
 
 
 
 
 
 
 
 
 
611
  else:
612
  st.error(f"Could not generate visualization: {status_message}")
613
  else:
@@ -618,21 +680,20 @@ def main():
618
  else:
619
  # Welcome message
620
  st.markdown("""
621
- ## πŸ‘‹ Welcome to the Stroke Classification System
622
 
623
- This AI system analyzes brain scan images and attempts to show you where the AI focuses its attention.
624
 
625
- ### πŸš€ Features:
626
- - **Automatic Architecture Detection**: Identifies your model type
627
- - **Smart Layer Selection**: Finds the best layer for attention visualization
628
- - **Fallback Strategies**: Works with different model architectures
629
- - **Transparent Process**: Shows exactly what's happening
630
 
631
- ### πŸ“‹ How to Use:
632
- 1. **Check the Enhanced Model Architecture Analysis** above
633
- 2. **Upload a brain scan image** using the sidebar
634
- 3. **View classification results** with confidence scores
635
- 4. **Explore attention visualization** - real or simulated based on your model
636
 
637
  **Get started by uploading an image! πŸ‘ˆ**
638
  """)
 
267
  except Exception as e:
268
  return None, f"Grad-CAM computation error: {str(e)}"
269
 
270
+ def enhance_heatmap_contrast(heatmap):
271
+ """Enhance heatmap contrast and dynamic range."""
272
+ if heatmap is None:
273
+ return None
274
+
275
+ # Apply histogram equalization-like enhancement
276
+ heatmap_flat = heatmap.flatten()
277
+
278
+ # Remove zeros for better contrast
279
+ non_zero_values = heatmap_flat[heatmap_flat > 0]
280
+ if len(non_zero_values) == 0:
281
+ return heatmap
282
+
283
+ # Enhance contrast using percentile stretching
284
+ p2, p98 = np.percentile(non_zero_values, [2, 98])
285
+ heatmap_enhanced = np.clip((heatmap - p2) / (p98 - p2), 0, 1)
286
+
287
+ # Apply power law transformation for better visibility
288
+ gamma = 0.5 # Makes mid-tones brighter
289
+ heatmap_enhanced = np.power(heatmap_enhanced, gamma)
290
+
291
+ return heatmap_enhanced
292
+
293
  def create_real_gradcam_heatmap(img, model, predictions):
294
+ """Create a real Grad-CAM heatmap with enhanced visualization."""
295
  try:
296
  # Preprocess image
297
  img_resized = img.resize((224, 224))
 
308
  layer_name, layer_status = find_best_layer_for_gradcam(model)
309
 
310
  if layer_name is None:
311
+ return None, f"❌ {layer_status}", None
312
 
313
  # Generate Grad-CAM heatmap
314
  heatmap, error = make_gradcam_heatmap(
 
319
  )
320
 
321
  if error:
322
+ return None, f"❌ {error} (Layer: {layer_name})", None
323
 
324
  if heatmap is not None:
325
  # Resize heatmap to match input image size
 
331
  else:
332
  heatmap_resized = heatmap
333
 
334
+ # Enhance contrast
335
+ heatmap_enhanced = enhance_heatmap_contrast(heatmap_resized)
336
+
337
+ # Get statistics for debugging
338
+ stats = {
339
+ 'min': float(np.min(heatmap_enhanced)),
340
+ 'max': float(np.max(heatmap_enhanced)),
341
+ 'mean': float(np.mean(heatmap_enhanced)),
342
+ 'std': float(np.std(heatmap_enhanced))
343
+ }
344
+
345
+ return heatmap_enhanced, f"βœ… Grad-CAM successful using layer: {layer_name}", stats
346
  else:
347
+ return None, "❌ Failed to generate heatmap", None
348
 
349
  except Exception as e:
350
+ return None, f"❌ Grad-CAM error: {str(e)}", None
351
 
352
  def predict_stroke(img, model):
353
  """Predict stroke type from image."""
 
390
  mask = (x - center_x)**2 + (y - center_y)**2
391
  heatmap = np.exp(-mask / (2 * (50**2))) * confidence
392
 
393
+ # Enhance simulated heatmap too
394
+ heatmap_enhanced = enhance_heatmap_contrast(heatmap)
395
+
396
+ stats = {
397
+ 'min': float(np.min(heatmap_enhanced)),
398
+ 'max': float(np.max(heatmap_enhanced)),
399
+ 'mean': float(np.mean(heatmap_enhanced)),
400
+ 'std': float(np.std(heatmap_enhanced))
401
+ }
402
+
403
+ return heatmap_enhanced, "⚠️ Using simulated heatmap (Grad-CAM failed)", stats
404
  except Exception as e:
405
+ return None, f"❌ Simulated heatmap error: {str(e)}", None
406
 
407
+ def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
408
+ """Create enhanced overlay visualization with better colors."""
409
  if not MPL_AVAILABLE:
410
  return None, "❌ Matplotlib not available"
411
 
 
416
 
417
  heatmap = None
418
  status_message = ""
419
+ stats = None
420
 
421
  # Try Grad-CAM first
422
  if force_gradcam and model is not None:
423
+ result = create_real_gradcam_heatmap(img, model, predictions)
424
+ if result and len(result) == 3:
425
+ heatmap, gradcam_status, stats = result
426
+ status_message = gradcam_status
427
 
428
  # Fallback to simulated if Grad-CAM failed
429
  if heatmap is None:
430
+ result = create_simulated_heatmap(img, predictions)
431
+ if result and len(result) == 3:
432
+ heatmap, sim_status, stats = result
433
+ if status_message:
434
+ status_message += f" | {sim_status}"
435
+ else:
436
+ status_message = sim_status
437
 
438
  if heatmap is None:
439
  return None, "❌ Could not generate any heatmap"
440
 
441
+ # Create enhanced visualization with multiple views
442
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
443
 
444
+ # 1. Original image
445
+ axes[0].imshow(img_array)
446
+ axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
447
+ axes[0].axis('off')
448
 
449
+ # 2. Heatmap only with enhanced colors
450
+ im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
451
+ axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
452
+ axes[1].axis('off')
453
+ plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
454
+
455
+ # 3. Overlay with better blending
456
+ axes[2].imshow(img_array)
457
+ im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
458
 
459
  # Determine title based on success
460
  if "βœ… Grad-CAM successful" in status_message:
461
+ title = "🎯 Real AI Attention Overlay"
462
  title_color = 'green'
463
  else:
464
+ title = "🎨 Simulated Attention Overlay"
465
  title_color = 'orange'
466
 
467
+ axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
468
+ axes[2].axis('off')
469
+ plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
 
 
 
470
 
471
  plt.tight_layout()
472
+
473
+ return fig, status_message, stats
474
 
475
  except Exception as e:
476
+ return None, f"❌ Visualization error: {str(e)}", None
477
 
478
  # Main App
479
  def main():
 
535
  for layer in analysis['dense_layers']:
536
  st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
537
 
 
 
 
 
 
 
 
538
  # Test layer selection
539
  layer_name, layer_status = find_best_layer_for_gradcam(st.session_state.model)
540
  if "βœ…" in layer_status:
 
543
  st.markdown(f'<div class="status-box warning"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
544
  else:
545
  st.markdown(f'<div class="status-box error"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
 
 
 
 
 
 
546
 
547
  # Manual reload button
548
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
 
567
  help="Try Grad-CAM even with non-CNN models (experimental)"
568
  )
569
 
570
+ colormap = st.selectbox(
571
+ "Color Scheme",
572
+ ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
573
+ index=0,
574
+ help="Choose color scheme for heatmap visualization"
575
+ )
576
+
577
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
578
  show_debug = st.checkbox("Show Debug Info", value=True)
579
+ show_stats = st.checkbox("Show Heatmap Statistics", value=True)
580
 
581
  st.markdown("---")
582
+ st.header("🎨 Color Scheme Guide")
583
  st.info("""
584
+ **hot**: Red-Yellow (classic heat)
585
+ **jet**: Blue-Green-Yellow-Red
586
+ **viridis**: Purple-Blue-Green-Yellow
587
+ **plasma**: Purple-Pink-Yellow
588
+ **inferno**: Black-Purple-Red-Yellow
589
+ **magma**: Black-Purple-Pink-White
590
+ **coolwarm**: Blue-White-Red
 
 
 
591
  """)
592
 
593
  if uploaded_file is not None:
 
595
  image = Image.open(uploaded_file)
596
 
597
  # Main content area
598
+ col1, col2 = st.columns([1, 2])
599
 
600
  with col1:
601
  st.subheader("πŸ“‹ Classification Results")
 
630
  st.error("❌ Model not loaded. Check the debug information above to see available files.")
631
 
632
  with col2:
633
+ st.subheader("🎯 Enhanced AI Attention Visualization")
634
 
635
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
636
+ # Create enhanced overlay visualization
637
+ with st.spinner("🎨 Generating enhanced attention visualization..."):
638
+ result = create_enhanced_overlay_visualization(
639
  image,
640
  predictions,
641
  st.session_state.model,
642
+ force_gradcam,
643
+ colormap
644
  )
645
 
646
+ if result and len(result) >= 2:
647
+ overlay_fig, status_message = result[0], result[1]
648
+ stats = result[2] if len(result) > 2 else None
649
+
650
  if overlay_fig is not None:
651
  st.pyplot(overlay_fig)
652
  plt.close()
 
659
  st.warning(f"⚠️ {status_message}")
660
  else:
661
  st.info(f"ℹ️ {status_message}")
662
+
663
+ # Show heatmap statistics
664
+ if show_stats and stats:
665
+ st.write("**πŸ“ˆ Heatmap Statistics:**")
666
+ col_stats1, col_stats2 = st.columns(2)
667
+ with col_stats1:
668
+ st.write(f"β€’ Min: {stats['min']:.3f}")
669
+ st.write(f"β€’ Max: {stats['max']:.3f}")
670
+ with col_stats2:
671
+ st.write(f"β€’ Mean: {stats['mean']:.3f}")
672
+ st.write(f"β€’ Std: {stats['std']:.3f}")
673
  else:
674
  st.error(f"Could not generate visualization: {status_message}")
675
  else:
 
680
  else:
681
  # Welcome message
682
  st.markdown("""
683
+ ## πŸ‘‹ Welcome to the Enhanced Stroke Classification System
684
 
685
+ This system now provides **better color visualization** and **enhanced contrast** for attention heatmaps.
686
 
687
+ ### 🎨 New Visualization Features:
688
+ - **Multiple Color Schemes**: Choose from 7 different color palettes
689
+ - **Enhanced Contrast**: Better visibility of attention patterns
690
+ - **Three-Panel View**: Original, heatmap, and overlay side-by-side
691
+ - **Statistics Display**: See heatmap value distributions
692
 
693
+ ### πŸš€ Why Colors Matter:
694
+ - **Red/Yellow (hot)**: High attention areas
695
+ - **Blue/Purple**: Low attention areas
696
+ - **Enhanced contrast**: Makes subtle patterns visible
 
697
 
698
  **Get started by uploading an image! πŸ‘ˆ**
699
  """)