bakhili commited on
Commit
93c1900
Β·
verified Β·
1 Parent(s): 9b1ad10

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +247 -228
src/streamlit_app.py CHANGED
@@ -129,40 +129,44 @@ def analyze_model_architecture(model):
129
  'conv_layers': [],
130
  'dense_layers': [],
131
  'other_layers': [],
132
- 'potential_gradcam_layers': [],
133
  'model_type': 'Unknown'
134
  }
135
 
136
  for i, layer in enumerate(model.layers):
137
  layer_type = type(layer).__name__
 
 
138
  layer_info = {
139
  'index': i,
140
  'name': layer.name,
141
  'type': layer_type,
142
- 'output_shape': getattr(layer, 'output_shape', 'Unknown')
 
 
143
  }
144
 
145
- # Categorize layers
 
 
 
 
 
 
 
 
 
146
  if any(conv_type in layer_type for conv_type in [
147
  'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
148
  'Convolution1D', 'Convolution2D', 'Convolution3D'
149
- ]):
150
  layer_analysis['conv_layers'].append(layer_info)
151
- layer_analysis['potential_gradcam_layers'].append(layer_info)
152
 
153
  elif 'Dense' in layer_type or 'Linear' in layer_type:
154
  layer_analysis['dense_layers'].append(layer_info)
155
 
156
- # Check for other layer types that might work with Grad-CAM
157
- elif any(layer_name in layer_type for layer_name in [
158
- 'Activation', 'BatchNormalization', 'Dropout', 'MaxPooling', 'AveragePooling',
159
- 'GlobalMaxPooling', 'GlobalAveragePooling', 'Flatten', 'Reshape'
160
- ]):
161
  layer_analysis['other_layers'].append(layer_info)
162
-
163
- # Some of these might be suitable for Grad-CAM if they have spatial dimensions
164
- if any(pool_type in layer_type for pool_type in ['MaxPooling2D', 'AveragePooling2D']):
165
- layer_analysis['potential_gradcam_layers'].append(layer_info)
166
 
167
  # Determine model type
168
  if layer_analysis['conv_layers']:
@@ -174,124 +178,112 @@ def analyze_model_architecture(model):
174
 
175
  return layer_analysis
176
 
177
- def find_best_layer_for_gradcam(model):
178
- """Find the best layer for Grad-CAM with expanded search."""
179
- if model is None:
180
- return None, "No model loaded"
181
-
182
- analysis = analyze_model_architecture(model)
183
-
184
- # Priority 1: Convolutional layers (best for Grad-CAM)
185
- if analysis['conv_layers']:
186
- best_layer = analysis['conv_layers'][-1] # Last conv layer
187
- return best_layer['name'], f"βœ… Using convolutional layer: {best_layer['name']} ({best_layer['type']})"
188
-
189
- # Priority 2: Pooling layers with spatial dimensions
190
- pooling_layers = [layer for layer in analysis['other_layers']
191
- if any(pool_type in layer['type'] for pool_type in ['MaxPooling2D', 'AveragePooling2D'])]
192
- if pooling_layers:
193
- best_layer = pooling_layers[-1]
194
- return best_layer['name'], f"⚠️ Using pooling layer: {best_layer['name']} ({best_layer['type']}) - may not work well"
195
-
196
- # Priority 3: Try any layer with 4D output (batch, height, width, channels)
197
- for layer in reversed(model.layers):
198
- try:
199
- output_shape = layer.output_shape
200
- if isinstance(output_shape, tuple) and len(output_shape) == 4:
201
- return layer.name, f"⚠️ Trying 4D layer: {layer.name} ({type(layer).__name__}) - experimental"
202
- except:
203
- continue
204
-
205
- # Priority 4: Try the layer before the last dense layer
206
- if analysis['dense_layers'] and len(model.layers) > 2:
207
- # Find the layer just before the first dense layer
208
- for i, layer in enumerate(model.layers):
209
- if 'Dense' in type(layer).__name__:
210
- if i > 0:
211
- prev_layer = model.layers[i-1]
212
- return prev_layer.name, f"⚠️ Using pre-dense layer: {prev_layer.name} ({type(prev_layer).__name__}) - may not work"
213
- break
214
 
215
- return None, "❌ No suitable layers found for Grad-CAM"
216
-
217
- def make_gradcam_heatmap(img_array, model, layer_name, pred_index=None):
218
- """Generate Grad-CAM heatmap with better error handling."""
219
  try:
220
- # Get the target layer
221
  target_layer = model.get_layer(layer_name)
 
222
 
223
- # Create a model that maps the input image to the activations of the target layer
224
  grad_model = tf.keras.Model(
225
  inputs=[model.inputs],
226
  outputs=[target_layer.output, model.output]
227
  )
228
 
229
- # Compute the gradient of the top predicted class
230
  with tf.GradientTape() as tape:
231
  layer_output, preds = grad_model(img_array)
 
 
 
232
  if pred_index is None:
233
  pred_index = tf.argmax(preds[0])
 
 
 
234
  class_channel = preds[:, pred_index]
 
235
 
236
- # Get gradients
237
  grads = tape.gradient(class_channel, layer_output)
238
 
239
  if grads is None:
240
- return None, "Gradients are None - layer may not be differentiable"
 
 
 
 
 
 
 
 
 
241
 
242
- # Handle different layer output shapes
243
- if len(layer_output.shape) == 4: # Standard conv layer (batch, height, width, channels)
244
- # Standard Grad-CAM for conv layers
 
245
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
246
  layer_output = layer_output[0]
247
  heatmap = layer_output @ pooled_grads[..., tf.newaxis]
248
  heatmap = tf.squeeze(heatmap)
249
-
250
- elif len(layer_output.shape) == 2: # Dense layer (batch, features)
251
- # For dense layers, create a simple attention based on gradients
252
- grads_abs = tf.abs(grads[0])
253
- heatmap = tf.reduce_mean(grads_abs)
254
- # Create a uniform heatmap since dense layers don't have spatial structure
255
- heatmap = tf.ones((7, 7)) * heatmap # Small heatmap to be resized later
256
-
257
  else:
258
- return None, f"Unsupported layer output shape: {layer_output.shape}"
 
 
 
 
 
 
 
 
 
259
 
260
- # Normalize the heatmap
261
  heatmap = tf.maximum(heatmap, 0)
262
- if tf.reduce_max(heatmap) > 0:
263
- heatmap = heatmap / tf.reduce_max(heatmap)
264
 
265
- return heatmap.numpy(), None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  except Exception as e:
268
- return None, f"Grad-CAM computation error: {str(e)}"
269
-
270
- def enhance_heatmap_contrast(heatmap):
271
- """Enhance heatmap contrast and dynamic range."""
272
- if heatmap is None:
273
- return None
274
-
275
- # Apply histogram equalization-like enhancement
276
- heatmap_flat = heatmap.flatten()
277
-
278
- # Remove zeros for better contrast
279
- non_zero_values = heatmap_flat[heatmap_flat > 0]
280
- if len(non_zero_values) == 0:
281
- return heatmap
282
-
283
- # Enhance contrast using percentile stretching
284
- p2, p98 = np.percentile(non_zero_values, [2, 98])
285
- heatmap_enhanced = np.clip((heatmap - p2) / (p98 - p2), 0, 1)
286
-
287
- # Apply power law transformation for better visibility
288
- gamma = 0.5 # Makes mid-tones brighter
289
- heatmap_enhanced = np.power(heatmap_enhanced, gamma)
290
-
291
- return heatmap_enhanced
292
 
293
- def create_real_gradcam_heatmap(img, model, predictions):
294
- """Create a real Grad-CAM heatmap with enhanced visualization."""
295
  try:
296
  # Preprocess image
297
  img_resized = img.resize((224, 224))
@@ -304,50 +296,67 @@ def create_real_gradcam_heatmap(img, model, predictions):
304
  # Normalize and add batch dimension
305
  img_array = np.expand_dims(img_array, axis=0) / 255.0
306
 
307
- # Find the best layer for Grad-CAM
308
- layer_name, layer_status = find_best_layer_for_gradcam(model)
309
 
310
- if layer_name is None:
311
- return None, f"❌ {layer_status}", None
312
 
313
- # Generate Grad-CAM heatmap
314
- heatmap, error = make_gradcam_heatmap(
315
- img_array,
316
- model,
317
- layer_name,
318
- pred_index=np.argmax(predictions)
319
- )
320
 
321
- if error:
322
- return None, f"❌ {error} (Layer: {layer_name})", None
 
 
 
 
323
 
324
- if heatmap is not None:
325
- # Resize heatmap to match input image size
326
- if heatmap.shape[0] < 224 or heatmap.shape[1] < 224:
327
- heatmap_resized = tf.image.resize(
328
- heatmap[..., tf.newaxis],
329
- (224, 224)
330
- ).numpy()[:, :, 0]
331
- else:
332
- heatmap_resized = heatmap
333
-
334
- # Enhance contrast
335
- heatmap_enhanced = enhance_heatmap_contrast(heatmap_resized)
336
-
337
- # Get statistics for debugging
338
- stats = {
339
- 'min': float(np.min(heatmap_enhanced)),
340
- 'max': float(np.max(heatmap_enhanced)),
341
- 'mean': float(np.mean(heatmap_enhanced)),
342
- 'std': float(np.std(heatmap_enhanced))
343
- }
344
 
345
- return heatmap_enhanced, f"βœ… Grad-CAM successful using layer: {layer_name}", stats
346
- else:
347
- return None, "❌ Failed to generate heatmap", None
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  except Exception as e:
350
- return None, f"❌ Grad-CAM error: {str(e)}", None
351
 
352
  def predict_stroke(img, model):
353
  """Predict stroke type from image."""
@@ -374,38 +383,58 @@ def predict_stroke(img, model):
374
  except Exception as e:
375
  return None, f"Prediction error: {str(e)}"
376
 
377
- def create_simulated_heatmap(img, predictions):
378
- """Create a simulated heatmap (fallback)."""
379
  try:
380
  confidence = np.max(predictions)
381
- np.random.seed(42)
382
- heatmap = np.random.rand(224, 224) * confidence
383
-
384
- try:
385
- from scipy import ndimage
386
- heatmap = ndimage.gaussian_filter(heatmap, sigma=20)
387
- except ImportError:
 
 
 
 
388
  center_x, center_y = 112, 112
389
- y, x = np.ogrid[:224, :224]
390
- mask = (x - center_x)**2 + (y - center_y)**2
391
- heatmap = np.exp(-mask / (2 * (50**2))) * confidence
392
 
393
- # Enhance simulated heatmap too
394
- heatmap_enhanced = enhance_heatmap_contrast(heatmap)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  stats = {
397
- 'min': float(np.min(heatmap_enhanced)),
398
- 'max': float(np.max(heatmap_enhanced)),
399
- 'mean': float(np.mean(heatmap_enhanced)),
400
- 'std': float(np.std(heatmap_enhanced))
401
  }
402
 
403
- return heatmap_enhanced, "⚠️ Using simulated heatmap (Grad-CAM failed)", stats
404
  except Exception as e:
405
  return None, f"❌ Simulated heatmap error: {str(e)}", None
406
 
407
- def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
408
- """Create enhanced overlay visualization with better colors."""
409
  if not MPL_AVAILABLE:
410
  return None, "❌ Matplotlib not available"
411
 
@@ -417,17 +446,20 @@ def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam
417
  heatmap = None
418
  status_message = ""
419
  stats = None
 
420
 
421
  # Try Grad-CAM first
422
  if force_gradcam and model is not None:
423
- result = create_real_gradcam_heatmap(img, model, predictions)
424
- if result and len(result) == 3:
425
- heatmap, gradcam_status, stats = result
 
 
426
  status_message = gradcam_status
427
 
428
- # Fallback to simulated if Grad-CAM failed
429
  if heatmap is None:
430
- result = create_simulated_heatmap(img, predictions)
431
  if result and len(result) == 3:
432
  heatmap, sim_status, stats = result
433
  if status_message:
@@ -436,9 +468,9 @@ def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam
436
  status_message = sim_status
437
 
438
  if heatmap is None:
439
- return None, "❌ Could not generate any heatmap"
440
 
441
- # Create enhanced visualization with multiple views
442
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
443
 
444
  # 1. Original image
@@ -446,13 +478,13 @@ def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam
446
  axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
447
  axes[0].axis('off')
448
 
449
- # 2. Heatmap only with enhanced colors
450
  im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
451
  axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
452
  axes[1].axis('off')
453
  plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
454
 
455
- # 3. Overlay with better blending
456
  axes[2].imshow(img_array)
457
  im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
458
 
@@ -470,10 +502,10 @@ def create_enhanced_overlay_visualization(img, predictions, model, force_gradcam
470
 
471
  plt.tight_layout()
472
 
473
- return fig, status_message, stats
474
 
475
  except Exception as e:
476
- return None, f"❌ Visualization error: {str(e)}", None
477
 
478
  # Main App
479
  def main():
@@ -514,7 +546,7 @@ def main():
514
 
515
  # Enhanced model architecture analysis
516
  if st.session_state.model is not None:
517
- with st.expander("πŸ” Enhanced Model Architecture Analysis"):
518
  analysis = analyze_model_architecture(st.session_state.model)
519
 
520
  st.write("**πŸ“Š Model Summary:**")
@@ -524,25 +556,11 @@ def main():
524
  st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
525
  st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
526
 
527
- # Show layer details
528
- if analysis['conv_layers']:
529
- st.write("**🎯 Convolutional Layers (Best for Grad-CAM):**")
530
- for layer in analysis['conv_layers']:
531
- st.write(f" βœ… {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
532
-
533
- if analysis['dense_layers']:
534
- st.write("**🧠 Dense Layers:**")
535
- for layer in analysis['dense_layers']:
536
- st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
537
-
538
- # Test layer selection
539
- layer_name, layer_status = find_best_layer_for_gradcam(st.session_state.model)
540
- if "βœ…" in layer_status:
541
- st.markdown(f'<div class="status-box success"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
542
- elif "⚠️" in layer_status:
543
- st.markdown(f'<div class="status-box warning"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
544
- else:
545
- st.markdown(f'<div class="status-box error"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
546
 
547
  # Manual reload button
548
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
@@ -564,7 +582,7 @@ def main():
564
  force_gradcam = st.checkbox(
565
  "Attempt Grad-CAM",
566
  value=True,
567
- help="Try Grad-CAM even with non-CNN models (experimental)"
568
  )
569
 
570
  colormap = st.selectbox(
@@ -577,18 +595,7 @@ def main():
577
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
578
  show_debug = st.checkbox("Show Debug Info", value=True)
579
  show_stats = st.checkbox("Show Heatmap Statistics", value=True)
580
-
581
- st.markdown("---")
582
- st.header("🎨 Color Scheme Guide")
583
- st.info("""
584
- **hot**: Red-Yellow (classic heat)
585
- **jet**: Blue-Green-Yellow-Red
586
- **viridis**: Purple-Blue-Green-Yellow
587
- **plasma**: Purple-Pink-Yellow
588
- **inferno**: Black-Purple-Red-Yellow
589
- **magma**: Black-Purple-Pink-White
590
- **coolwarm**: Blue-White-Red
591
- """)
592
 
593
  if uploaded_file is not None:
594
  # Load image
@@ -630,12 +637,12 @@ def main():
630
  st.error("❌ Model not loaded. Check the debug information above to see available files.")
631
 
632
  with col2:
633
- st.subheader("🎯 Enhanced AI Attention Visualization")
634
 
635
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
636
- # Create enhanced overlay visualization
637
- with st.spinner("🎨 Generating enhanced attention visualization..."):
638
- result = create_enhanced_overlay_visualization(
639
  image,
640
  predictions,
641
  st.session_state.model,
@@ -646,6 +653,7 @@ def main():
646
  if result and len(result) >= 2:
647
  overlay_fig, status_message = result[0], result[1]
648
  stats = result[2] if len(result) > 2 else None
 
649
 
650
  if overlay_fig is not None:
651
  st.pyplot(overlay_fig)
@@ -658,20 +666,30 @@ def main():
658
  elif "⚠️" in status_message:
659
  st.warning(f"⚠️ {status_message}")
660
  else:
661
- st.info(f"ℹ️ {status_message}")
662
 
663
  # Show heatmap statistics
664
  if show_stats and stats:
665
  st.write("**πŸ“ˆ Heatmap Statistics:**")
666
- col_stats1, col_stats2 = st.columns(2)
667
- with col_stats1:
668
- st.write(f"β€’ Min: {stats['min']:.3f}")
669
- st.write(f"β€’ Max: {stats['max']:.3f}")
670
- with col_stats2:
671
- st.write(f"β€’ Mean: {stats['mean']:.3f}")
672
- st.write(f"β€’ Std: {stats['std']:.3f}")
 
 
 
 
 
 
 
 
673
  else:
674
  st.error(f"Could not generate visualization: {status_message}")
 
 
675
  else:
676
  st.error("Could not generate attention visualization")
677
  else:
@@ -680,22 +698,23 @@ def main():
680
  else:
681
  # Welcome message
682
  st.markdown("""
683
- ## πŸ‘‹ Welcome to the Enhanced Stroke Classification System
684
 
685
- This system now provides **better color visualization** and **enhanced contrast** for attention heatmaps.
686
 
687
- ### 🎨 New Visualization Features:
688
- - **Multiple Color Schemes**: Choose from 7 different color palettes
689
- - **Enhanced Contrast**: Better visibility of attention patterns
690
- - **Three-Panel View**: Original, heatmap, and overlay side-by-side
691
- - **Statistics Display**: See heatmap value distributions
692
 
693
- ### πŸš€ Why Colors Matter:
694
- - **Red/Yellow (hot)**: High attention areas
695
- - **Blue/Purple**: Low attention areas
696
- - **Enhanced contrast**: Makes subtle patterns visible
 
697
 
698
- **Get started by uploading an image! πŸ‘ˆ**
699
  """)
700
 
701
  # Medical disclaimer
 
129
  'conv_layers': [],
130
  'dense_layers': [],
131
  'other_layers': [],
132
+ 'all_layers_detailed': [],
133
  'model_type': 'Unknown'
134
  }
135
 
136
  for i, layer in enumerate(model.layers):
137
  layer_type = type(layer).__name__
138
+
139
+ # Get more detailed layer information
140
  layer_info = {
141
  'index': i,
142
  'name': layer.name,
143
  'type': layer_type,
144
+ 'output_shape': getattr(layer, 'output_shape', 'Unknown'),
145
+ 'trainable': getattr(layer, 'trainable', 'Unknown'),
146
+ 'activation': getattr(layer, 'activation', None)
147
  }
148
 
149
+ # Try to get activation function name
150
+ if hasattr(layer, 'activation') and layer.activation:
151
+ try:
152
+ layer_info['activation'] = layer.activation.__name__
153
+ except:
154
+ layer_info['activation'] = str(layer.activation)
155
+
156
+ layer_analysis['all_layers_detailed'].append(layer_info)
157
+
158
+ # Categorize layers with more comprehensive detection
159
  if any(conv_type in layer_type for conv_type in [
160
  'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
161
  'Convolution1D', 'Convolution2D', 'Convolution3D'
162
+ ]) or 'conv' in layer.name.lower():
163
  layer_analysis['conv_layers'].append(layer_info)
 
164
 
165
  elif 'Dense' in layer_type or 'Linear' in layer_type:
166
  layer_analysis['dense_layers'].append(layer_info)
167
 
168
+ else:
 
 
 
 
169
  layer_analysis['other_layers'].append(layer_info)
 
 
 
 
170
 
171
  # Determine model type
172
  if layer_analysis['conv_layers']:
 
178
 
179
  return layer_analysis
180
 
181
+ def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
182
+ """Debug Grad-CAM computation step by step."""
183
+ debug_info = {
184
+ 'step': 'Starting',
185
+ 'error': None,
186
+ 'layer_output_shape': None,
187
+ 'gradients_shape': None,
188
+ 'gradients_stats': None,
189
+ 'heatmap_stats': None
190
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
 
 
 
 
192
  try:
193
+ debug_info['step'] = 'Getting target layer'
194
  target_layer = model.get_layer(layer_name)
195
+ debug_info['target_layer_type'] = type(target_layer).__name__
196
 
197
+ debug_info['step'] = 'Creating grad model'
198
  grad_model = tf.keras.Model(
199
  inputs=[model.inputs],
200
  outputs=[target_layer.output, model.output]
201
  )
202
 
203
+ debug_info['step'] = 'Computing forward pass'
204
  with tf.GradientTape() as tape:
205
  layer_output, preds = grad_model(img_array)
206
+ debug_info['layer_output_shape'] = layer_output.shape.as_list()
207
+ debug_info['predictions_shape'] = preds.shape.as_list()
208
+
209
  if pred_index is None:
210
  pred_index = tf.argmax(preds[0])
211
+ debug_info['pred_index'] = int(pred_index)
212
+ debug_info['pred_confidence'] = float(preds[0][pred_index])
213
+
214
  class_channel = preds[:, pred_index]
215
+ debug_info['class_channel_shape'] = class_channel.shape.as_list()
216
 
217
+ debug_info['step'] = 'Computing gradients'
218
  grads = tape.gradient(class_channel, layer_output)
219
 
220
  if grads is None:
221
+ debug_info['error'] = "Gradients are None - no backpropagation path"
222
+ return None, debug_info
223
+
224
+ debug_info['gradients_shape'] = grads.shape.as_list()
225
+ debug_info['gradients_stats'] = {
226
+ 'min': float(tf.reduce_min(grads)),
227
+ 'max': float(tf.reduce_max(grads)),
228
+ 'mean': float(tf.reduce_mean(grads)),
229
+ 'std': float(tf.math.reduce_std(grads))
230
+ }
231
 
232
+ debug_info['step'] = 'Processing gradients based on layer type'
233
+
234
+ if len(layer_output.shape) == 4: # Conv layer
235
+ debug_info['processing_type'] = 'Convolutional layer (4D)'
236
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
237
  layer_output = layer_output[0]
238
  heatmap = layer_output @ pooled_grads[..., tf.newaxis]
239
  heatmap = tf.squeeze(heatmap)
240
+
241
+ elif len(layer_output.shape) == 2: # Dense layer
242
+ debug_info['processing_type'] = 'Dense layer (2D)'
243
+ # For dense layers, create spatial heatmap from gradient magnitude
244
+ grads_magnitude = tf.reduce_mean(tf.abs(grads))
245
+ # Create a simple spatial pattern
246
+ heatmap = tf.ones((14, 14)) * grads_magnitude
247
+
248
  else:
249
+ debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
250
+ return None, debug_info
251
+
252
+ debug_info['step'] = 'Normalizing heatmap'
253
+ debug_info['raw_heatmap_stats'] = {
254
+ 'min': float(tf.reduce_min(heatmap)),
255
+ 'max': float(tf.reduce_max(heatmap)),
256
+ 'mean': float(tf.reduce_mean(heatmap)),
257
+ 'std': float(tf.math.reduce_std(heatmap))
258
+ }
259
 
260
+ # Apply ReLU (remove negative values)
261
  heatmap = tf.maximum(heatmap, 0)
 
 
262
 
263
+ # Normalize
264
+ heatmap_max = tf.reduce_max(heatmap)
265
+ if heatmap_max > 0:
266
+ heatmap = heatmap / heatmap_max
267
+ else:
268
+ debug_info['error'] = "All heatmap values are zero or negative"
269
+ return None, debug_info
270
+
271
+ debug_info['final_heatmap_stats'] = {
272
+ 'min': float(tf.reduce_min(heatmap)),
273
+ 'max': float(tf.reduce_max(heatmap)),
274
+ 'mean': float(tf.reduce_mean(heatmap)),
275
+ 'std': float(tf.math.reduce_std(heatmap))
276
+ }
277
+
278
+ debug_info['step'] = 'Complete'
279
+ return heatmap.numpy(), debug_info
280
 
281
  except Exception as e:
282
+ debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
283
+ return None, debug_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
+ def create_robust_gradcam_heatmap(img, model, predictions):
286
+ """Create Grad-CAM with comprehensive debugging."""
287
  try:
288
  # Preprocess image
289
  img_resized = img.resize((224, 224))
 
296
  # Normalize and add batch dimension
297
  img_array = np.expand_dims(img_array, axis=0) / 255.0
298
 
299
+ # Get model analysis
300
+ analysis = analyze_model_architecture(model)
301
 
302
+ # Try different layers in order of preference
303
+ layer_candidates = []
304
 
305
+ # Add conv layers first
306
+ for layer in analysis['conv_layers']:
307
+ layer_candidates.append((layer['name'], f"Conv layer: {layer['name']}"))
 
 
 
 
308
 
309
+ # Add other potentially suitable layers
310
+ for layer in analysis['all_layers_detailed']:
311
+ if (layer['type'] in ['Activation', 'BatchNormalization'] and
312
+ isinstance(layer['output_shape'], (list, tuple)) and
313
+ len(layer['output_shape']) == 4):
314
+ layer_candidates.append((layer['name'], f"4D layer: {layer['name']} ({layer['type']})"))
315
 
316
+ # Try dense layers as last resort
317
+ if not layer_candidates:
318
+ for layer in analysis['dense_layers']:
319
+ layer_candidates.append((layer['name'], f"Dense layer: {layer['name']} (experimental)"))
320
+
321
+ if not layer_candidates:
322
+ return None, "❌ No suitable layers found", None
323
+
324
+ # Try each candidate layer
325
+ for layer_name, layer_desc in layer_candidates:
326
+ pred_index = np.argmax(predictions)
 
 
 
 
 
 
 
 
 
327
 
328
+ heatmap, debug_info = debug_gradcam_step_by_step(
329
+ img_array, model, layer_name, pred_index
330
+ )
331
 
332
+ if heatmap is not None:
333
+ # Resize heatmap to match input image size
334
+ if heatmap.shape[0] != 224 or heatmap.shape[1] != 224:
335
+ heatmap_resized = tf.image.resize(
336
+ heatmap[..., tf.newaxis],
337
+ (224, 224)
338
+ ).numpy()[:, :, 0]
339
+ else:
340
+ heatmap_resized = heatmap
341
+
342
+ # Final statistics
343
+ stats = {
344
+ 'min': float(np.min(heatmap_resized)),
345
+ 'max': float(np.max(heatmap_resized)),
346
+ 'mean': float(np.mean(heatmap_resized)),
347
+ 'std': float(np.std(heatmap_resized))
348
+ }
349
+
350
+ return heatmap_resized, f"βœ… Grad-CAM successful using {layer_desc}", stats, debug_info
351
+ else:
352
+ # Continue to next layer if this one failed
353
+ continue
354
+
355
+ # If all layers failed, return debug info from the last attempt
356
+ return None, f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info
357
+
358
  except Exception as e:
359
+ return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)}
360
 
361
  def predict_stroke(img, model):
362
  """Predict stroke type from image."""
 
383
  except Exception as e:
384
  return None, f"Prediction error: {str(e)}"
385
 
386
+ def create_enhanced_simulated_heatmap(img, predictions):
387
+ """Create a more realistic simulated heatmap."""
388
  try:
389
  confidence = np.max(predictions)
390
+ predicted_class = np.argmax(predictions)
391
+
392
+ # Create different patterns based on predicted class
393
+ if predicted_class == 0: # Hemorrhagic
394
+ # Focus on center-left region
395
+ center_x, center_y = 80, 112
396
+ elif predicted_class == 1: # Ischemic
397
+ # Focus on right side
398
+ center_x, center_y = 150, 112
399
+ else: # No stroke
400
+ # Diffuse, low-intensity pattern
401
  center_x, center_y = 112, 112
 
 
 
402
 
403
+ # Create base pattern
404
+ y, x = np.ogrid[:224, :224]
405
+
406
+ # Primary focus area
407
+ mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))
408
+
409
+ # Secondary areas
410
+ mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2)))
411
+ mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2)))
412
+
413
+ # Combine patterns
414
+ heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence
415
+
416
+ # Add some noise for realism
417
+ np.random.seed(42)
418
+ noise = np.random.normal(0, 0.05, heatmap.shape)
419
+ heatmap = np.maximum(heatmap + noise, 0)
420
+
421
+ # Normalize
422
+ if np.max(heatmap) > 0:
423
+ heatmap = heatmap / np.max(heatmap)
424
 
425
  stats = {
426
+ 'min': float(np.min(heatmap)),
427
+ 'max': float(np.max(heatmap)),
428
+ 'mean': float(np.mean(heatmap)),
429
+ 'std': float(np.std(heatmap))
430
  }
431
 
432
+ return heatmap, "⚠️ Using enhanced simulated heatmap", stats
433
  except Exception as e:
434
  return None, f"❌ Simulated heatmap error: {str(e)}", None
435
 
436
+ def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
437
+ """Create comprehensive visualization with debugging."""
438
  if not MPL_AVAILABLE:
439
  return None, "❌ Matplotlib not available"
440
 
 
446
  heatmap = None
447
  status_message = ""
448
  stats = None
449
+ debug_info = None
450
 
451
  # Try Grad-CAM first
452
  if force_gradcam and model is not None:
453
+ result = create_robust_gradcam_heatmap(img, model, predictions)
454
+ if result and len(result) >= 3:
455
+ heatmap, gradcam_status, stats = result[0], result[1], result[2]
456
+ if len(result) > 3:
457
+ debug_info = result[3]
458
  status_message = gradcam_status
459
 
460
+ # Fallback to enhanced simulated if Grad-CAM failed
461
  if heatmap is None:
462
+ result = create_enhanced_simulated_heatmap(img, predictions)
463
  if result and len(result) == 3:
464
  heatmap, sim_status, stats = result
465
  if status_message:
 
468
  status_message = sim_status
469
 
470
  if heatmap is None:
471
+ return None, "❌ Could not generate any heatmap", None, None
472
 
473
+ # Create visualization
474
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
475
 
476
  # 1. Original image
 
478
  axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
479
  axes[0].axis('off')
480
 
481
+ # 2. Heatmap only
482
  im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
483
  axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
484
  axes[1].axis('off')
485
  plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
486
 
487
+ # 3. Overlay
488
  axes[2].imshow(img_array)
489
  im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
490
 
 
502
 
503
  plt.tight_layout()
504
 
505
+ return fig, status_message, stats, debug_info
506
 
507
  except Exception as e:
508
+ return None, f"❌ Visualization error: {str(e)}", None, None
509
 
510
  # Main App
511
  def main():
 
546
 
547
  # Enhanced model architecture analysis
548
  if st.session_state.model is not None:
549
+ with st.expander("πŸ” Detailed Model Architecture Analysis"):
550
  analysis = analyze_model_architecture(st.session_state.model)
551
 
552
  st.write("**πŸ“Š Model Summary:**")
 
556
  st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
557
  st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
558
 
559
+ # Show detailed layer information
560
+ st.write("**πŸ” All Layers (Detailed):**")
561
+ for layer in analysis['all_layers_detailed']:
562
+ activation_info = f" | Activation: {layer['activation']}" if layer['activation'] else ""
563
+ st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | Shape: {layer['output_shape']}{activation_info}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
  # Manual reload button
566
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
 
582
  force_gradcam = st.checkbox(
583
  "Attempt Grad-CAM",
584
  value=True,
585
+ help="Try Grad-CAM with comprehensive debugging"
586
  )
587
 
588
  colormap = st.selectbox(
 
595
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
596
  show_debug = st.checkbox("Show Debug Info", value=True)
597
  show_stats = st.checkbox("Show Heatmap Statistics", value=True)
598
+ show_detailed_debug = st.checkbox("Show Detailed Debug Info", value=False)
 
 
 
 
 
 
 
 
 
 
 
599
 
600
  if uploaded_file is not None:
601
  # Load image
 
637
  st.error("❌ Model not loaded. Check the debug information above to see available files.")
638
 
639
  with col2:
640
+ st.subheader("🎯 Comprehensive AI Attention Visualization")
641
 
642
  if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
643
+ # Create comprehensive visualization
644
+ with st.spinner("🎨 Generating comprehensive attention visualization..."):
645
+ result = create_comprehensive_visualization(
646
  image,
647
  predictions,
648
  st.session_state.model,
 
653
  if result and len(result) >= 2:
654
  overlay_fig, status_message = result[0], result[1]
655
  stats = result[2] if len(result) > 2 else None
656
+ debug_info = result[3] if len(result) > 3 else None
657
 
658
  if overlay_fig is not None:
659
  st.pyplot(overlay_fig)
 
666
  elif "⚠️" in status_message:
667
  st.warning(f"⚠️ {status_message}")
668
  else:
669
+ st.error(f"❌ {status_message}")
670
 
671
  # Show heatmap statistics
672
  if show_stats and stats:
673
  st.write("**πŸ“ˆ Heatmap Statistics:**")
674
+ if any(np.isnan([stats['min'], stats['max'], stats['mean'], stats['std']])):
675
+ st.error("⚠️ NaN values detected in heatmap - this indicates a computation error")
676
+ else:
677
+ col_stats1, col_stats2 = st.columns(2)
678
+ with col_stats1:
679
+ st.write(f"β€’ Min: {stats['min']:.3f}")
680
+ st.write(f"β€’ Max: {stats['max']:.3f}")
681
+ with col_stats2:
682
+ st.write(f"β€’ Mean: {stats['mean']:.3f}")
683
+ st.write(f"β€’ Std: {stats['std']:.3f}")
684
+
685
+ # Show detailed debug information
686
+ if show_detailed_debug and debug_info:
687
+ with st.expander("πŸ”§ Detailed Debug Information"):
688
+ st.json(debug_info)
689
  else:
690
  st.error(f"Could not generate visualization: {status_message}")
691
+ if debug_info:
692
+ st.error(f"Debug info: {debug_info.get('error', 'No additional info')}")
693
  else:
694
  st.error("Could not generate attention visualization")
695
  else:
 
698
  else:
699
  # Welcome message
700
  st.markdown("""
701
+ ## πŸ‘‹ Welcome to the Comprehensive Stroke Classification System
702
 
703
+ This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.
704
 
705
+ ### πŸ”§ New Debugging Features:
706
+ - **Step-by-step Grad-CAM debugging** - See exactly where it fails
707
+ - **Multiple layer attempts** - Tries different layers automatically
708
+ - **Enhanced error messages** - Clear explanations of what went wrong
709
+ - **NaN detection** - Identifies computation errors
710
 
711
+ ### 🎯 What to Look For:
712
+ - **Green success messages** - Grad-CAM is working
713
+ - **Orange warnings** - Using fallback methods
714
+ - **Red errors** - Something is broken
715
+ - **NaN statistics** - Computation failure
716
 
717
+ **Upload an image to see detailed debugging! πŸ‘ˆ**
718
  """)
719
 
720
  # Medical disclaimer