bakhili commited on
Commit
3130041
Β·
verified Β·
1 Parent(s): 9735798

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +181 -82
src/streamlit_app.py CHANGED
@@ -119,83 +119,156 @@ def load_stroke_model():
119
  except Exception as e:
120
  return None, f"❌ Model loading failed: {str(e)}"
121
 
122
- def debug_model_layers(model):
123
- """Debug function to show all model layers."""
124
  if model is None:
125
- return "No model loaded"
126
 
127
- layer_info = []
128
- conv_layers = []
 
 
 
 
 
 
129
 
130
  for i, layer in enumerate(model.layers):
131
  layer_type = type(layer).__name__
132
- layer_info.append(f"{i}: {layer.name} ({layer_type})")
133
-
134
- if 'conv' in layer.name.lower() or 'Conv' in layer_type:
135
- conv_layers.append(f"βœ… {layer.name} ({layer_type})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- return {
138
- 'all_layers': layer_info,
139
- 'conv_layers': conv_layers,
140
- 'total_layers': len(model.layers)
141
- }
 
 
 
 
142
 
143
- def find_best_conv_layer(model):
144
- """Find the best convolutional layer for Grad-CAM."""
145
  if model is None:
146
  return None, "No model loaded"
147
 
148
- conv_layers = []
149
 
150
- # Look for convolutional layers
151
- for layer in model.layers:
152
- layer_type = type(layer).__name__
153
- if any(conv_type in layer_type for conv_type in ['Conv2D', 'Conv', 'SeparableConv2D']):
154
- conv_layers.append(layer.name)
 
 
 
 
 
 
155
 
156
- if not conv_layers:
157
- return None, "No convolutional layers found"
 
 
 
 
 
 
158
 
159
- # Return the last convolutional layer
160
- return conv_layers[-1], f"Found {len(conv_layers)} conv layers, using: {conv_layers[-1]}"
 
 
 
 
 
 
 
 
 
161
 
162
- def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
163
- """Generate Grad-CAM heatmap."""
164
  try:
165
- # Create a model that maps the input image to the activations of the last conv layer
 
 
 
166
  grad_model = tf.keras.Model(
167
  inputs=[model.inputs],
168
- outputs=[model.get_layer(last_conv_layer_name).output, model.output]
169
  )
170
 
171
  # Compute the gradient of the top predicted class
172
  with tf.GradientTape() as tape:
173
- last_conv_layer_output, preds = grad_model(img_array)
174
  if pred_index is None:
175
  pred_index = tf.argmax(preds[0])
176
  class_channel = preds[:, pred_index]
177
 
178
- # Gradient of the output neuron with regard to the output feature map
179
- grads = tape.gradient(class_channel, last_conv_layer_output)
180
 
181
- # Mean intensity of the gradient over a specific feature map channel
182
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
183
 
184
- # Multiply each channel by "how important this channel is"
185
- last_conv_layer_output = last_conv_layer_output[0]
186
- heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
187
- heatmap = tf.squeeze(heatmap)
 
 
 
188
 
189
- # Normalize the heatmap between 0 & 1
190
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  return heatmap.numpy(), None
193
 
194
  except Exception as e:
195
- return None, f"Grad-CAM error: {str(e)}"
196
 
197
  def create_real_gradcam_heatmap(img, model, predictions):
198
- """Create a real Grad-CAM heatmap."""
199
  try:
200
  # Preprocess image
201
  img_resized = img.resize((224, 224))
@@ -208,31 +281,34 @@ def create_real_gradcam_heatmap(img, model, predictions):
208
  # Normalize and add batch dimension
209
  img_array = np.expand_dims(img_array, axis=0) / 255.0
210
 
211
- # Find the best convolutional layer
212
- conv_layer_name, layer_status = find_best_conv_layer(model)
213
 
214
- if conv_layer_name is None:
215
  return None, f"❌ {layer_status}"
216
 
217
  # Generate Grad-CAM heatmap
218
  heatmap, error = make_gradcam_heatmap(
219
  img_array,
220
  model,
221
- conv_layer_name,
222
  pred_index=np.argmax(predictions)
223
  )
224
 
225
  if error:
226
- return None, f"❌ {error}"
227
 
228
  if heatmap is not None:
229
  # Resize heatmap to match input image size
230
- heatmap_resized = tf.image.resize(
231
- heatmap[..., tf.newaxis],
232
- (224, 224)
233
- ).numpy()[:, :, 0]
 
 
 
234
 
235
- return heatmap_resized, f"βœ… Grad-CAM successful using layer: {conv_layer_name}"
236
  else:
237
  return None, "❌ Failed to generate heatmap"
238
 
@@ -380,29 +456,50 @@ def main():
380
  # Model status details
381
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
382
 
383
- # Debug model architecture
384
  if st.session_state.model is not None:
385
- with st.expander("πŸ” Model Architecture Debug"):
386
- debug_info = debug_model_layers(st.session_state.model)
387
 
388
  st.write("**πŸ“Š Model Summary:**")
389
- st.write(f"- Total layers: {debug_info['total_layers']}")
390
- st.write(f"- Convolutional layers found: {len(debug_info['conv_layers'])}")
 
 
 
391
 
392
- if debug_info['conv_layers']:
393
- st.write("**🎯 Available Convolutional Layers:**")
394
- for conv_layer in debug_info['conv_layers']:
395
- st.write(f" {conv_layer}")
396
-
397
- # Test which layer we'll use
398
- conv_layer_name, layer_status = find_best_conv_layer(st.session_state.model)
399
- st.markdown(f'<div class="status-box info"><strong>Selected for Grad-CAM:</strong> {layer_status}</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  else:
401
- st.markdown('<div class="status-box error">❌ No convolutional layers found - Grad-CAM will not work</div>', unsafe_allow_html=True)
402
 
403
- with st.expander("All Layers (Advanced)"):
404
- for layer_info in debug_info['all_layers']:
405
- st.code(layer_info)
 
 
406
 
407
  # Manual reload button
408
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
@@ -422,9 +519,9 @@ def main():
422
  st.header("🎨 Visualization Options")
423
 
424
  force_gradcam = st.checkbox(
425
- "Force Grad-CAM Attempt",
426
  value=True,
427
- help="Always try Grad-CAM first (recommended)"
428
  )
429
 
430
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
@@ -433,7 +530,7 @@ def main():
433
  st.markdown("---")
434
  st.header("ℹ️ About")
435
  st.info("""
436
- **Model Architecture:** Deep Learning CNN
437
 
438
  **Classes:**
439
  - Hemorrhagic Stroke
@@ -442,7 +539,7 @@ def main():
442
 
443
  **Input:** 224Γ—224 RGB images
444
 
445
- **Attention Method:** Grad-CAM
446
  """)
447
 
448
  if uploaded_file is not None:
@@ -507,8 +604,10 @@ def main():
507
  if show_debug:
508
  if "βœ… Grad-CAM successful" in status_message:
509
  st.success(f"βœ… {status_message}")
510
- else:
511
  st.warning(f"⚠️ {status_message}")
 
 
512
  else:
513
  st.error(f"Could not generate visualization: {status_message}")
514
  else:
@@ -521,19 +620,19 @@ def main():
521
  st.markdown("""
522
  ## πŸ‘‹ Welcome to the Stroke Classification System
523
 
524
- This AI system analyzes brain scan images and shows you **exactly where the AI is looking**.
525
 
526
  ### πŸš€ Features:
527
- - **Deep Learning Classification**: Advanced CNN architecture
528
- - **Real AI Attention Maps**: See actual model reasoning with Grad-CAM
529
- - **Debug Information**: Understand why Grad-CAM works or fails
530
- - **Transparent AI**: Full visibility into the decision process
531
 
532
  ### πŸ“‹ How to Use:
533
- 1. **Check system status** and **model debug info** above
534
  2. **Upload a brain scan image** using the sidebar
535
  3. **View classification results** with confidence scores
536
- 4. **Explore attention visualization** - it will tell you if it's real or simulated
537
 
538
  **Get started by uploading an image! πŸ‘ˆ**
539
  """)
 
119
  except Exception as e:
120
  return None, f"❌ Model loading failed: {str(e)}"
121
 
122
+ def analyze_model_architecture(model):
123
+ """Comprehensive analysis of model architecture."""
124
  if model is None:
125
+ return {"error": "No model loaded"}
126
 
127
+ layer_analysis = {
128
+ 'total_layers': len(model.layers),
129
+ 'conv_layers': [],
130
+ 'dense_layers': [],
131
+ 'other_layers': [],
132
+ 'potential_gradcam_layers': [],
133
+ 'model_type': 'Unknown'
134
+ }
135
 
136
  for i, layer in enumerate(model.layers):
137
  layer_type = type(layer).__name__
138
+ layer_info = {
139
+ 'index': i,
140
+ 'name': layer.name,
141
+ 'type': layer_type,
142
+ 'output_shape': getattr(layer, 'output_shape', 'Unknown')
143
+ }
144
+
145
+ # Categorize layers
146
+ if any(conv_type in layer_type for conv_type in [
147
+ 'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
148
+ 'Convolution1D', 'Convolution2D', 'Convolution3D'
149
+ ]):
150
+ layer_analysis['conv_layers'].append(layer_info)
151
+ layer_analysis['potential_gradcam_layers'].append(layer_info)
152
+
153
+ elif 'Dense' in layer_type or 'Linear' in layer_type:
154
+ layer_analysis['dense_layers'].append(layer_info)
155
+
156
+ # Check for other layer types that might work with Grad-CAM
157
+ elif any(layer_name in layer_type for layer_name in [
158
+ 'Activation', 'BatchNormalization', 'Dropout', 'MaxPooling', 'AveragePooling',
159
+ 'GlobalMaxPooling', 'GlobalAveragePooling', 'Flatten', 'Reshape'
160
+ ]):
161
+ layer_analysis['other_layers'].append(layer_info)
162
+
163
+ # Some of these might be suitable for Grad-CAM if they have spatial dimensions
164
+ if any(pool_type in layer_type for pool_type in ['MaxPooling2D', 'AveragePooling2D']):
165
+ layer_analysis['potential_gradcam_layers'].append(layer_info)
166
 
167
+ # Determine model type
168
+ if layer_analysis['conv_layers']:
169
+ layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
170
+ elif layer_analysis['dense_layers']:
171
+ layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
172
+ else:
173
+ layer_analysis['model_type'] = 'Custom Architecture'
174
+
175
+ return layer_analysis
176
 
177
+ def find_best_layer_for_gradcam(model):
178
+ """Find the best layer for Grad-CAM with expanded search."""
179
  if model is None:
180
  return None, "No model loaded"
181
 
182
+ analysis = analyze_model_architecture(model)
183
 
184
+ # Priority 1: Convolutional layers (best for Grad-CAM)
185
+ if analysis['conv_layers']:
186
+ best_layer = analysis['conv_layers'][-1] # Last conv layer
187
+ return best_layer['name'], f"βœ… Using convolutional layer: {best_layer['name']} ({best_layer['type']})"
188
+
189
+ # Priority 2: Pooling layers with spatial dimensions
190
+ pooling_layers = [layer for layer in analysis['other_layers']
191
+ if any(pool_type in layer['type'] for pool_type in ['MaxPooling2D', 'AveragePooling2D'])]
192
+ if pooling_layers:
193
+ best_layer = pooling_layers[-1]
194
+ return best_layer['name'], f"⚠️ Using pooling layer: {best_layer['name']} ({best_layer['type']}) - may not work well"
195
 
196
+ # Priority 3: Try any layer with 4D output (batch, height, width, channels)
197
+ for layer in reversed(model.layers):
198
+ try:
199
+ output_shape = layer.output_shape
200
+ if isinstance(output_shape, tuple) and len(output_shape) == 4:
201
+ return layer.name, f"⚠️ Trying 4D layer: {layer.name} ({type(layer).__name__}) - experimental"
202
+ except:
203
+ continue
204
 
205
+ # Priority 4: Try the layer before the last dense layer
206
+ if analysis['dense_layers'] and len(model.layers) > 2:
207
+ # Find the layer just before the first dense layer
208
+ for i, layer in enumerate(model.layers):
209
+ if 'Dense' in type(layer).__name__:
210
+ if i > 0:
211
+ prev_layer = model.layers[i-1]
212
+ return prev_layer.name, f"⚠️ Using pre-dense layer: {prev_layer.name} ({type(prev_layer).__name__}) - may not work"
213
+ break
214
+
215
+ return None, "❌ No suitable layers found for Grad-CAM"
216
 
217
+ def make_gradcam_heatmap(img_array, model, layer_name, pred_index=None):
218
+ """Generate Grad-CAM heatmap with better error handling."""
219
  try:
220
+ # Get the target layer
221
+ target_layer = model.get_layer(layer_name)
222
+
223
+ # Create a model that maps the input image to the activations of the target layer
224
  grad_model = tf.keras.Model(
225
  inputs=[model.inputs],
226
+ outputs=[target_layer.output, model.output]
227
  )
228
 
229
  # Compute the gradient of the top predicted class
230
  with tf.GradientTape() as tape:
231
+ layer_output, preds = grad_model(img_array)
232
  if pred_index is None:
233
  pred_index = tf.argmax(preds[0])
234
  class_channel = preds[:, pred_index]
235
 
236
+ # Get gradients
237
+ grads = tape.gradient(class_channel, layer_output)
238
 
239
+ if grads is None:
240
+ return None, "Gradients are None - layer may not be differentiable"
241
 
242
+ # Handle different layer output shapes
243
+ if len(layer_output.shape) == 4: # Standard conv layer (batch, height, width, channels)
244
+ # Standard Grad-CAM for conv layers
245
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
246
+ layer_output = layer_output[0]
247
+ heatmap = layer_output @ pooled_grads[..., tf.newaxis]
248
+ heatmap = tf.squeeze(heatmap)
249
 
250
+ elif len(layer_output.shape) == 2: # Dense layer (batch, features)
251
+ # For dense layers, create a simple attention based on gradients
252
+ grads_abs = tf.abs(grads[0])
253
+ heatmap = tf.reduce_mean(grads_abs)
254
+ # Create a uniform heatmap since dense layers don't have spatial structure
255
+ heatmap = tf.ones((7, 7)) * heatmap # Small heatmap to be resized later
256
+
257
+ else:
258
+ return None, f"Unsupported layer output shape: {layer_output.shape}"
259
+
260
+ # Normalize the heatmap
261
+ heatmap = tf.maximum(heatmap, 0)
262
+ if tf.reduce_max(heatmap) > 0:
263
+ heatmap = heatmap / tf.reduce_max(heatmap)
264
 
265
  return heatmap.numpy(), None
266
 
267
  except Exception as e:
268
+ return None, f"Grad-CAM computation error: {str(e)}"
269
 
270
  def create_real_gradcam_heatmap(img, model, predictions):
271
+ """Create a real Grad-CAM heatmap with fallback strategies."""
272
  try:
273
  # Preprocess image
274
  img_resized = img.resize((224, 224))
 
281
  # Normalize and add batch dimension
282
  img_array = np.expand_dims(img_array, axis=0) / 255.0
283
 
284
+ # Find the best layer for Grad-CAM
285
+ layer_name, layer_status = find_best_layer_for_gradcam(model)
286
 
287
+ if layer_name is None:
288
  return None, f"❌ {layer_status}"
289
 
290
  # Generate Grad-CAM heatmap
291
  heatmap, error = make_gradcam_heatmap(
292
  img_array,
293
  model,
294
+ layer_name,
295
  pred_index=np.argmax(predictions)
296
  )
297
 
298
  if error:
299
+ return None, f"❌ {error} (Layer: {layer_name})"
300
 
301
  if heatmap is not None:
302
  # Resize heatmap to match input image size
303
+ if heatmap.shape[0] < 224 or heatmap.shape[1] < 224:
304
+ heatmap_resized = tf.image.resize(
305
+ heatmap[..., tf.newaxis],
306
+ (224, 224)
307
+ ).numpy()[:, :, 0]
308
+ else:
309
+ heatmap_resized = heatmap
310
 
311
+ return heatmap_resized, f"βœ… Grad-CAM successful using layer: {layer_name}"
312
  else:
313
  return None, "❌ Failed to generate heatmap"
314
 
 
456
  # Model status details
457
  st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
458
 
459
+ # Enhanced model architecture analysis
460
  if st.session_state.model is not None:
461
+ with st.expander("πŸ” Enhanced Model Architecture Analysis"):
462
+ analysis = analyze_model_architecture(st.session_state.model)
463
 
464
  st.write("**πŸ“Š Model Summary:**")
465
+ st.write(f"- **Model Type:** {analysis['model_type']}")
466
+ st.write(f"- **Total Layers:** {analysis['total_layers']}")
467
+ st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
468
+ st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
469
+ st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
470
 
471
+ # Show layer details
472
+ if analysis['conv_layers']:
473
+ st.write("**🎯 Convolutional Layers (Best for Grad-CAM):**")
474
+ for layer in analysis['conv_layers']:
475
+ st.write(f" βœ… {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
476
+
477
+ if analysis['dense_layers']:
478
+ st.write("**🧠 Dense Layers:**")
479
+ for layer in analysis['dense_layers']:
480
+ st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
481
+
482
+ if analysis['other_layers']:
483
+ st.write("**βš™οΈ Other Layers:**")
484
+ for layer in analysis['other_layers'][:5]: # Show first 5
485
+ st.write(f" β€’ {layer['name']} ({layer['type']}) - Shape: {layer['output_shape']}")
486
+ if len(analysis['other_layers']) > 5:
487
+ st.write(f" ... and {len(analysis['other_layers']) - 5} more")
488
+
489
+ # Test layer selection
490
+ layer_name, layer_status = find_best_layer_for_gradcam(st.session_state.model)
491
+ if "βœ…" in layer_status:
492
+ st.markdown(f'<div class="status-box success"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
493
+ elif "⚠️" in layer_status:
494
+ st.markdown(f'<div class="status-box warning"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
495
  else:
496
+ st.markdown(f'<div class="status-box error"><strong>Grad-CAM Layer:</strong> {layer_status}</div>', unsafe_allow_html=True)
497
 
498
+ # Show model architecture recommendation
499
+ if analysis['model_type'] == 'MLP (Multi-Layer Perceptron)':
500
+ st.warning("⚠️ **Note:** Your model appears to be a Multi-Layer Perceptron (MLP) without convolutional layers. Grad-CAM works best with CNNs. The visualization may be limited or use experimental methods.")
501
+ elif analysis['model_type'] == 'Custom Architecture':
502
+ st.info("ℹ️ **Note:** Your model has a custom architecture. Grad-CAM compatibility depends on the specific layers used.")
503
 
504
  # Manual reload button
505
  if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
 
519
  st.header("🎨 Visualization Options")
520
 
521
  force_gradcam = st.checkbox(
522
+ "Attempt Grad-CAM",
523
  value=True,
524
+ help="Try Grad-CAM even with non-CNN models (experimental)"
525
  )
526
 
527
  show_probabilities = st.checkbox("Show All Probabilities", value=True)
 
530
  st.markdown("---")
531
  st.header("ℹ️ About")
532
  st.info("""
533
+ **Model Architecture:** Detected automatically
534
 
535
  **Classes:**
536
  - Hemorrhagic Stroke
 
539
 
540
  **Input:** 224Γ—224 RGB images
541
 
542
+ **Attention Method:** Grad-CAM (when possible)
543
  """)
544
 
545
  if uploaded_file is not None:
 
604
  if show_debug:
605
  if "βœ… Grad-CAM successful" in status_message:
606
  st.success(f"βœ… {status_message}")
607
+ elif "⚠️" in status_message:
608
  st.warning(f"⚠️ {status_message}")
609
+ else:
610
+ st.info(f"ℹ️ {status_message}")
611
  else:
612
  st.error(f"Could not generate visualization: {status_message}")
613
  else:
 
620
  st.markdown("""
621
  ## πŸ‘‹ Welcome to the Stroke Classification System
622
 
623
+ This AI system analyzes brain scan images and attempts to show you where the AI focuses its attention.
624
 
625
  ### πŸš€ Features:
626
+ - **Automatic Architecture Detection**: Identifies your model type
627
+ - **Smart Layer Selection**: Finds the best layer for attention visualization
628
+ - **Fallback Strategies**: Works with different model architectures
629
+ - **Transparent Process**: Shows exactly what's happening
630
 
631
  ### πŸ“‹ How to Use:
632
+ 1. **Check the Enhanced Model Architecture Analysis** above
633
  2. **Upload a brain scan image** using the sidebar
634
  3. **View classification results** with confidence scores
635
+ 4. **Explore attention visualization** - real or simulated based on your model
636
 
637
  **Get started by uploading an image! πŸ‘ˆ**
638
  """)