# Source page: stroke-classification/src/streamlit_app.py
# Last commit 93c1900 "Update src/streamlit_app.py" (verified) by bakhili; file size 28.7 kB.
import streamlit as st
import numpy as np
import os
import sys
from PIL import Image
# Set environment variables to fix permission issues.
# MPLCONFIGDIR must be set before matplotlib is imported (below) so that its
# config/cache directory is created in a writable location.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

# Optional heavy dependencies. The app degrades gracefully when either is
# missing; TF_AVAILABLE / MPL_AVAILABLE are checked before any use.
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    # Deliberately no st.error() here: st.set_page_config() is called later in
    # this module, and Streamlit (at least in many releases) requires it to be
    # the first st.* command of the script. The missing dependency is still
    # reported by the status panel and by the model loader, which both check
    # TF_AVAILABLE.
    TF_AVAILABLE = False

try:
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend: figures are rendered off-screen.
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False
# ---------------------------------------------------------------------------
# Page configuration and global CSS
# ---------------------------------------------------------------------------
_page_settings = {
    "page_title": "Stroke Classifier",
    "page_icon": "🧠",
    "layout": "wide",
}
st.set_page_config(**_page_settings)

# CSS classes referenced by the HTML snippets rendered in main()
# (header, prediction box and the coloured status boxes).
_custom_css = """
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.prediction-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 1rem;
text-align: center;
margin: 1rem 0;
}
.status-box {
padding: 1rem;
border-radius: 0.5rem;
margin: 1rem 0;
}
.success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
.error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
.info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
.warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
.debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
</style>"""
st.markdown(_custom_css, unsafe_allow_html=True)
# Per-session slots for the model handle. Streamlit re-runs the script on
# every interaction, so they are only created the first time.
if 'model_loaded' not in st.session_state:
    st.session_state['model_loaded'] = False
    st.session_state['model'] = None
    st.session_state['model_status'] = "Not loaded"

# Display names, indexed by the class index the model predicts.
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]
def find_model_file():
    """Locate the Keras ``.h5`` model file to load.

    Well-known locations are checked first. Only if none of them exists is
    the current working-directory tree searched, and the search stops at the
    first ``.h5`` file found.

    Returns:
        The first existing candidate path as a string, or ``None`` when no
        candidate exists.
    """
    preferred_paths = [
        "stroke_classification_model.h5",
        "./stroke_classification_model.h5",
        "/app/stroke_classification_model.h5",
        "src/stroke_classification_model.h5",
        os.path.join(os.getcwd(), "stroke_classification_model.h5"),
    ]
    for path in preferred_paths:
        if os.path.exists(path):
            return path

    # Fallback: walk the tree lazily. Directories and files are visited in
    # sorted order so the chosen file does not depend on the filesystem's
    # listing order. The exists() check skips dangling symlinks.
    for root, dirs, files in os.walk('.'):
        dirs.sort()  # in-place sort also fixes the order os.walk descends in
        for name in sorted(files):
            if name.endswith('.h5'):
                candidate = os.path.join(root, name)
                if os.path.exists(candidate):
                    return candidate
    return None
@st.cache_resource
def load_stroke_model():
    """Load the stroke-classification Keras model.

    The result is memoised by ``st.cache_resource``, so later calls (including
    reruns) return the same tuple until that cache is cleared.

    Returns:
        ``(model, status_message)``. ``model`` is ``None`` on any failure, and
        ``status_message`` then explains why.
    """
    from itertools import islice

    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"
    try:
        model_path = find_model_file()
        if model_path is None:
            # Help debug a missing or misnamed model by listing a few files.
            # Only the first 10 are reported, so stop walking the tree once
            # they are collected instead of materialising every path.
            all_files = (
                os.path.join(root, name)
                for root, _dirs, files in os.walk('.')
                for name in files
            )
            current_files = list(islice(all_files, 10))
            return None, f"❌ Model file not found. Available files: {current_files}"

        st.info(f"Found model at: {model_path}")
        # compile=False: inference only, so the training configuration
        # (optimizer/loss) is not re-compiled.
        model = tf.keras.models.load_model(model_path, compile=False)
        # NOTE(review): "βœ…" looks like a mis-encoded "✅". Other code in this
        # file matches this prefix by substring, so change all occurrences
        # together if fixing it.
        return model, f"βœ… Model loaded successfully from: {model_path}"
    except Exception as e:
        return None, f"❌ Model loading failed: {e}"
def _layer_output_shape(layer):
    """Return *layer*'s output shape, or ``'Unknown'`` if it cannot be determined.

    ``layer.output_shape`` is used when present. Otherwise (it is missing on
    recent Keras layers) this falls back to ``layer.output.shape`` and returns
    it as a tuple.
    """
    shape = getattr(layer, 'output_shape', None)
    if shape is not None:
        return shape
    try:
        return tuple(layer.output.shape)
    except (AttributeError, TypeError, ValueError, RuntimeError):
        # Layer never called, has several inbound nodes, or multiple outputs.
        return 'Unknown'


def analyze_model_architecture(model):
    """Summarise a Keras model's layers for display and Grad-CAM layer choice.

    Args:
        model: A Keras model (anything exposing ``.layers``), or ``None``.

    Returns:
        ``{"error": "No model loaded"}`` when *model* is ``None``. Otherwise a
        dict with:

        - ``total_layers``;
        - ``conv_layers`` / ``dense_layers`` / ``other_layers`` (per-category
          lists);
        - ``all_layers_detailed`` (one info dict per layer, in model order);
        - a coarse ``model_type`` label.

        Each info dict has ``index``, ``name``, ``type``, ``output_shape``,
        ``trainable`` and ``activation`` (the function's name when it has one).
    """
    if model is None:
        return {"error": "No model loaded"}

    conv_type_names = (
        'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
        'Convolution1D', 'Convolution2D', 'Convolution3D',
    )

    layer_analysis = {
        'total_layers': len(model.layers),
        'conv_layers': [],
        'dense_layers': [],
        'other_layers': [],
        'all_layers_detailed': [],
        'model_type': 'Unknown',
    }

    for i, layer in enumerate(model.layers):
        layer_type = type(layer).__name__
        layer_info = {
            'index': i,
            'name': layer.name,
            'type': layer_type,
            'output_shape': _layer_output_shape(layer),
            'trainable': getattr(layer, 'trainable', 'Unknown'),
            'activation': getattr(layer, 'activation', None),
        }
        # Prefer a readable activation name over the function object's repr.
        activation = layer_info['activation']
        if activation:
            try:
                layer_info['activation'] = activation.__name__
            except AttributeError:
                # Callable objects without __name__ (e.g. functools.partial).
                layer_info['activation'] = str(activation)
        layer_analysis['all_layers_detailed'].append(layer_info)

        # Substring matching on the class name, plus a name-based heuristic.
        if any(t in layer_type for t in conv_type_names) or 'conv' in layer.name.lower():
            layer_analysis['conv_layers'].append(layer_info)
        elif 'Dense' in layer_type or 'Linear' in layer_type:
            layer_analysis['dense_layers'].append(layer_info)
        else:
            layer_analysis['other_layers'].append(layer_info)

    if layer_analysis['conv_layers']:
        layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
    elif layer_analysis['dense_layers']:
        layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
    else:
        layer_analysis['model_type'] = 'Custom Architecture'
    return layer_analysis
def _tf_tensor_stats(tensor):
    """Return min, max, mean and standard deviation of a TF tensor as floats."""
    return {
        'min': float(tf.reduce_min(tensor)),
        'max': float(tf.reduce_max(tensor)),
        'mean': float(tf.reduce_mean(tensor)),
        'std': float(tf.math.reduce_std(tensor)),
    }


def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
    """Compute a Grad-CAM heatmap for one layer while recording diagnostics.

    Args:
        img_array: Preprocessed input batch for *model*. Only the first
            sample's activations are used for a 4-D (conv) layer.
        model: Keras model being explained.
        layer_name: Name of the layer whose activations are weighted.
        pred_index: Class index to explain; ``None`` selects the top-scoring
            class.

    Returns:
        ``(heatmap, debug_info)``. ``heatmap`` is a 2-D NumPy array scaled to
        ``[0, 1]``, or ``None`` on failure. In the failure case
        ``debug_info['error']`` explains why and ``debug_info['step']`` names
        the stage that was reached. Exceptions are caught, not raised.
    """
    debug_info = {
        'step': 'Starting',
        'error': None,
        'layer_output_shape': None,
        'gradients_shape': None,
        'gradients_stats': None,
        'heatmap_stats': None,  # kept for compatibility; not populated
    }

    try:
        debug_info['step'] = 'Getting target layer'
        target_layer = model.get_layer(layer_name)
        debug_info['target_layer_type'] = type(target_layer).__name__

        debug_info['step'] = 'Creating grad model'
        # model.inputs is already a list. Wrapping it again ([model.inputs])
        # creates a nested input structure that newer Keras versions warn
        # about or reject when the model is then called with one array.
        grad_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output],
        )

        debug_info['step'] = 'Computing forward pass'
        with tf.GradientTape() as tape:
            layer_output, preds = grad_model(img_array)
            debug_info['layer_output_shape'] = layer_output.shape.as_list()
            debug_info['predictions_shape'] = preds.shape.as_list()
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            debug_info['pred_index'] = int(pred_index)
            debug_info['pred_confidence'] = float(preds[0][pred_index])
            # Score of the explained class; its gradient drives the heatmap.
            class_channel = preds[:, pred_index]
            debug_info['class_channel_shape'] = class_channel.shape.as_list()

        debug_info['step'] = 'Computing gradients'
        grads = tape.gradient(class_channel, layer_output)
        if grads is None:
            debug_info['error'] = "Gradients are None - no backpropagation path"
            return None, debug_info
        debug_info['gradients_shape'] = grads.shape.as_list()
        debug_info['gradients_stats'] = _tf_tensor_stats(grads)

        debug_info['step'] = 'Processing gradients based on layer type'
        rank = len(layer_output.shape)
        if rank == 4:  # Conv-style feature maps
            debug_info['processing_type'] = 'Convolutional layer (4D)'
            # Channel weights = gradients averaged over all axes except the
            # last. This assumes a channels-last (batch, H, W, C) layout.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            heatmap = layer_output[0] @ pooled_grads[..., tf.newaxis]
            # Drop only the trailing singleton axis so a 1x1 map stays 2-D;
            # an axis-less squeeze would collapse it to a scalar.
            heatmap = tf.squeeze(heatmap, axis=-1)
        elif rank == 2:  # Dense layer
            debug_info['processing_type'] = 'Dense layer (2D)'
            # NOTE(review): dense activations have no spatial layout, so this
            # is a constant 14x14 map that normalises to all ones.
            grads_magnitude = tf.reduce_mean(tf.abs(grads))
            heatmap = tf.ones((14, 14)) * grads_magnitude
        else:
            debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
            return None, debug_info

        debug_info['step'] = 'Normalizing heatmap'
        debug_info['raw_heatmap_stats'] = _tf_tensor_stats(heatmap)

        # ReLU: keep only features with a positive influence on the class.
        heatmap = tf.maximum(heatmap, 0)
        heatmap_max = tf.reduce_max(heatmap)
        if heatmap_max > 0:
            heatmap = heatmap / heatmap_max
        else:
            debug_info['error'] = "All heatmap values are zero or negative"
            return None, debug_info

        debug_info['final_heatmap_stats'] = _tf_tensor_stats(heatmap)
        debug_info['step'] = 'Complete'
        return heatmap.numpy(), debug_info

    except Exception as e:
        debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
        return None, debug_info
def create_robust_gradcam_heatmap(img, model, predictions):
    """Try Grad-CAM on candidate layers until one yields a usable heatmap.

    The candidates are tried in this order:

    - convolutional layers, deepest first (Grad-CAM conventionally uses the
      last convolutional feature maps);
    - 4-D Activation/BatchNormalization layers, deepest first;
    - dense layers, only as an experimental last resort.

    Args:
        img: PIL image as uploaded (any mode).
        model: Loaded Keras model.
        predictions: Class scores for *img*; the arg-max class is explained.

    Returns:
        Always a 4-tuple ``(heatmap, status, stats, debug_info)``:

        - ``heatmap``: a 224x224 array in ``[0, 1]``, or ``None``;
        - ``status``: a human-readable message;
        - ``stats``: a min/max/mean/std dict, or ``None``;
        - ``debug_info``: diagnostics from the last attempt, or ``None`` when
          nothing was attempted.
    """
    try:
        # Same preprocessing as inference: 224x224, 3-channel RGB in [0, 1].
        # convert('RGB') also normalises grayscale, palette, RGBA and CMYK
        # inputs, which would otherwise yield 2-D or 4-channel arrays.
        img_rgb = img.resize((224, 224)).convert('RGB')
        img_array = np.expand_dims(np.array(img_rgb, dtype=np.float32), axis=0) / 255.0

        analysis = analyze_model_architecture(model)

        # name -> description. dict keeps insertion order and de-duplicates
        # layers that qualify twice (e.g. a BatchNormalization named "conv...").
        candidates = {}
        for layer in reversed(analysis['conv_layers']):
            candidates.setdefault(layer['name'], f"Conv layer: {layer['name']}")
        for layer in reversed(analysis['all_layers_detailed']):
            shape = layer['output_shape']
            if (layer['type'] in ('Activation', 'BatchNormalization')
                    and isinstance(shape, (list, tuple))
                    and len(shape) == 4):
                candidates.setdefault(layer['name'], f"4D layer: {layer['name']} ({layer['type']})")
        if not candidates:
            for layer in analysis['dense_layers']:
                candidates.setdefault(layer['name'], f"Dense layer: {layer['name']} (experimental)")
        if not candidates:
            return None, "❌ No suitable layers found", None, None

        pred_index = int(np.argmax(predictions))
        debug_info = None
        for layer_name, layer_desc in candidates.items():
            heatmap, debug_info = debug_gradcam_step_by_step(
                img_array, model, layer_name, pred_index
            )
            if heatmap is None:
                continue  # this layer failed; try the next candidate

            # Upsample the layer-resolution map to the 224x224 input size.
            if heatmap.shape[:2] != (224, 224):
                heatmap = tf.image.resize(heatmap[..., tf.newaxis], (224, 224)).numpy()[:, :, 0]

            stats = {
                'min': float(np.min(heatmap)),
                'max': float(np.max(heatmap)),
                'mean': float(np.mean(heatmap)),
                'std': float(np.std(heatmap)),
            }
            # NOTE(review): "βœ…" looks like a mis-encoded "✅"; this status text
            # is matched by substring elsewhere, so change all uses together.
            return heatmap, f"βœ… Grad-CAM successful using {layer_desc}", stats, debug_info

        # Every candidate failed: report the last attempt's diagnostics.
        return None, f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info

    except Exception as e:
        return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)}
def predict_stroke(img, model):
    """Classify a brain-scan image.

    Args:
        img: PIL image in any mode. It is resized to 224x224 and converted to
            3-channel RGB scaled to ``[0, 1]`` before inference.
        model: Loaded Keras model, or ``None``.

    Returns:
        ``(scores, None)`` on success, where ``scores`` is the model's output
        row for this image. On failure, ``(None, error_message)``.
    """
    if model is None:
        return None, "Model not loaded"
    try:
        # convert('RGB') turns grayscale ('L'), palette ('P'), RGBA and CMYK
        # inputs into the 3-channel layout fed to the model. For 'L' it
        # replicates the single channel three times.
        img_rgb = img.resize((224, 224)).convert('RGB')
        img_array = np.array(img_rgb, dtype=np.float32)
        # Scale to [0, 1] and add the batch dimension.
        img_array = np.expand_dims(img_array, axis=0) / 255.0
        predictions = model.predict(img_array, verbose=0)
        return predictions[0], None
    except Exception as e:
        return None, f"Prediction error: {str(e)}"
def create_enhanced_simulated_heatmap(img, predictions):
    """Build a synthetic, class-dependent attention map for display.

    The map is NOT derived from the model. It combines three fixed Gaussian
    blobs whose position depends on the predicted class and whose amplitude
    scales with the top score, then adds fixed-seed noise. It is shown when
    Grad-CAM cannot be computed.

    Args:
        img: Unused; kept for interface compatibility.
        predictions: 1-D array of class scores.

    Returns:
        ``(heatmap, status, stats)`` with a 224x224 map normalised to
        ``[0, 1]``, or ``(None, error_message, None)`` on failure.
    """
    try:
        confidence = np.max(predictions)
        predicted_class = np.argmax(predictions)

        # Centre (x, y) of the main blob, chosen by predicted class index.
        if predicted_class == 0:  # Hemorrhagic: centre-left region
            center_x, center_y = 80, 112
        elif predicted_class == 1:  # Ischemic: right side
            center_x, center_y = 150, 112
        else:  # No stroke: diffuse pattern around the image centre
            center_x, center_y = 112, 112

        y, x = np.ogrid[:224, :224]
        # Primary focus area plus two smaller offset blobs.
        mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))
        mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2)))
        mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2)))
        heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence

        # Deterministic noise from a private generator. RandomState(42)
        # yields the same values as np.random.seed(42) + np.random.normal(),
        # without reseeding NumPy's process-global RNG as a side effect.
        noise = np.random.RandomState(42).normal(0, 0.05, heatmap.shape)
        heatmap = np.maximum(heatmap + noise, 0)

        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        stats = {
            'min': float(np.min(heatmap)),
            'max': float(np.max(heatmap)),
            'mean': float(np.mean(heatmap)),
            'std': float(np.std(heatmap)),
        }
        return heatmap, "⚠️ Using enhanced simulated heatmap", stats
    except Exception as e:
        return None, f"❌ Simulated heatmap error: {str(e)}", None
def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
    """Render the original image, the attention heatmap and their overlay.

    Grad-CAM is tried first when *force_gradcam* is set and a model is given.
    If no heatmap results, the simulated heatmap is used instead.

    Args:
        img: PIL image as uploaded (any mode).
        predictions: Class scores for *img*.
        model: Loaded Keras model, or ``None``.
        force_gradcam: Whether to attempt Grad-CAM at all.
        colormap: Matplotlib colormap name used for the heatmap.

    Returns:
        Always a 4-tuple ``(figure, status, stats, debug_info)``. ``figure``
        is ``None`` on failure. The caller owns a returned figure and should
        close it after display.
    """
    if not MPL_AVAILABLE:
        return None, "❌ Matplotlib not available", None, None

    fig = None
    try:
        # Convert for display: a 2-D grayscale array would otherwise be drawn
        # with matplotlib's default colormap instead of in gray.
        img_array = np.array(img.resize((224, 224)).convert('RGB'))

        heatmap = None
        status_message = ""
        stats = None
        debug_info = None

        if force_gradcam and model is not None:
            result = create_robust_gradcam_heatmap(img, model, predictions)
            if result and len(result) >= 3:
                heatmap, status_message, stats = result[0], result[1], result[2]
                if len(result) > 3:
                    debug_info = result[3]

        # Fall back to the simulated map when Grad-CAM produced nothing.
        if heatmap is None:
            result = create_enhanced_simulated_heatmap(img, predictions)
            if result and len(result) == 3:
                heatmap, sim_status, stats = result
                status_message = f"{status_message} | {sim_status}" if status_message else sim_status

        if heatmap is None:
            return None, "❌ Could not generate any heatmap", None, None

        # Matched without the leading icon so the check does not depend on
        # how the status emoji is encoded.
        is_real_gradcam = "Grad-CAM successful" in status_message

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # 1. Original image
        axes[0].imshow(img_array)
        axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
        axes[0].axis('off')

        # 2. Heatmap only
        im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
        axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
        axes[1].axis('off')
        fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

        # 3. Overlay
        axes[2].imshow(img_array)
        im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
        if is_real_gradcam:
            title, title_color = "🎯 Real AI Attention Overlay", 'green'
        else:
            title, title_color = "🎨 Simulated Attention Overlay", 'orange'
        axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
        axes[2].axis('off')
        fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

        fig.tight_layout()
        return fig, status_message, stats, debug_info

    except Exception as e:
        # Don't leak a half-built figure into pyplot's global figure registry.
        if fig is not None:
            plt.close(fig)
        return None, f"❌ Visualization error: {str(e)}", None, None
# Main App
def main():
    """Render the Streamlit page: status panel, model details, upload and results.

    Streamlit re-runs the script on every widget interaction, so this function
    executes top to bottom each time. The model comes from the cached
    load_stroke_model(), and its handle and status live in st.session_state.
    """
    import html

    # NOTE(review): several UI literals below look mis-encoded (UTF-8 emoji
    # decoded with a legacy code page, e.g. "βœ…" where "✅" was probably
    # intended). They are reproduced verbatim; fix them file-wide in one pass.

    # Header
    st.markdown('<h1 class="main-header">🧠 AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)

    # Load the model once per session (load_stroke_model itself is cached).
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    # System status
    st.markdown('### πŸ"§ System Status')
    col1, col2, col3 = st.columns(3)
    with col1:
        if TF_AVAILABLE:
            st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
    with col2:
        if MPL_AVAILABLE:
            st.markdown('<div class="status-box success">βœ… Matplotlib Ready</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
    with col3:
        # Judge success by the model handle rather than by searching the
        # status text for an icon: load_stroke_model() returns a model only
        # on success.
        if st.session_state.model is not None:
            st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Model Error</div>', unsafe_allow_html=True)

    # Model status details. The status can embed file paths and exception
    # text, so escape it before interpolating it into raw HTML.
    model_status = html.escape(str(st.session_state.model_status))
    st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {model_status}</div>', unsafe_allow_html=True)

    # Enhanced model architecture analysis
    if st.session_state.model is not None:
        with st.expander('πŸ" Detailed Model Architecture Analysis'):
            analysis = analyze_model_architecture(st.session_state.model)
            st.write('**πŸ"Š Model Summary:**')
            st.write(f"- **Model Type:** {analysis['model_type']}")
            st.write(f"- **Total Layers:** {analysis['total_layers']}")
            st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
            st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
            st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
            st.write('**πŸ" All Layers (Detailed):**')
            for layer in analysis['all_layers_detailed']:
                activation_info = f" | Activation: {layer['activation']}" if layer['activation'] else ""
                st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | Shape: {layer['output_shape']}{activation_info}")

    # Manual reload button
    if st.button('πŸ"„ Reload Model', help="Try to reload the model"):
        # load_stroke_model is wrapped in st.cache_resource. Without clearing
        # that cache, the rerun would just return the previously cached result.
        load_stroke_model.clear()
        st.session_state.model_loaded = False
        st.rerun()

    # Sidebar
    with st.sidebar:
        st.header('πŸ"€ Upload Brain Scan')
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification",
        )
        st.markdown("---")
        st.header("🎨 Visualization Options")
        force_gradcam = st.checkbox(
            "Attempt Grad-CAM",
            value=True,
            help="Try Grad-CAM with comprehensive debugging",
        )
        colormap = st.selectbox(
            "Color Scheme",
            ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
            index=0,
            help="Choose color scheme for heatmap visualization",
        )
        show_probabilities = st.checkbox("Show All Probabilities", value=True)
        show_debug = st.checkbox("Show Debug Info", value=True)
        show_stats = st.checkbox("Show Heatmap Statistics", value=True)
        show_detailed_debug = st.checkbox("Show Detailed Debug Info", value=False)

    # Open the upload. Report unreadable files instead of crashing the page.
    image = None
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
        except Exception as exc:  # UI boundary: PIL raises UnidentifiedImageError (an OSError) among others
            st.error(f"❌ Could not open the uploaded file as an image: {exc}")

    if image is not None:
        # Main content area
        col1, col2 = st.columns([1, 2])
        # Set up front so the visualization column can test it safely.
        predictions = None

        with col1:
            st.subheader('πŸ"‹ Classification Results')
            if st.session_state.model is not None:
                with st.spinner('πŸ" Analyzing brain scan...'):
                    predictions, error = predict_stroke(image, st.session_state.model)
                if error:
                    st.error(error)
                else:
                    # Top prediction
                    class_idx = np.argmax(predictions)
                    confidence = predictions[class_idx] * 100
                    predicted_class = STROKE_LABELS[class_idx]
                    st.markdown(f"""
<div class="prediction-box">
<h2>{predicted_class}</h2>
<h3>Confidence: {confidence:.1f}%</h3>
</div>
""", unsafe_allow_html=True)
                    if show_probabilities:
                        st.write('**πŸ"Š All Probabilities:**')
                        for label, prob in zip(STROKE_LABELS, predictions):
                            st.write(f"β€’ {label}: {prob*100:.1f}%")
            else:
                st.error("❌ Model not loaded. Check the debug information above to see available files.")

        with col2:
            st.subheader("🎯 Comprehensive AI Attention Visualization")
            if st.session_state.model is not None and predictions is not None:
                with st.spinner("🎨 Generating comprehensive attention visualization..."):
                    result = create_comprehensive_visualization(
                        image,
                        predictions,
                        st.session_state.model,
                        force_gradcam,
                        colormap,
                    )
                if result and len(result) >= 2:
                    overlay_fig, status_message = result[0], result[1]
                    stats = result[2] if len(result) > 2 else None
                    debug_info = result[3] if len(result) > 3 else None
                    if overlay_fig is not None:
                        st.pyplot(overlay_fig)
                        plt.close(overlay_fig)  # close exactly the figure we were handed

                        # Status messages already carry their own leading
                        # icon, so they are shown as-is.
                        if show_debug:
                            if "Grad-CAM successful" in status_message:
                                st.success(status_message)
                            elif "⚠️" in status_message:
                                st.warning(status_message)
                            else:
                                st.error(status_message)

                        if show_stats and stats:
                            st.write('**πŸ"ˆ Heatmap Statistics:**')
                            if any(np.isnan([stats['min'], stats['max'], stats['mean'], stats['std']])):
                                st.error("⚠️ NaN values detected in heatmap - this indicates a computation error")
                            else:
                                col_stats1, col_stats2 = st.columns(2)
                                with col_stats1:
                                    st.write(f"β€’ Min: {stats['min']:.3f}")
                                    st.write(f"β€’ Max: {stats['max']:.3f}")
                                with col_stats2:
                                    st.write(f"β€’ Mean: {stats['mean']:.3f}")
                                    st.write(f"β€’ Std: {stats['std']:.3f}")

                        if show_detailed_debug and debug_info:
                            with st.expander('πŸ"§ Detailed Debug Information'):
                                st.json(debug_info)
                    else:
                        st.error(f"Could not generate visualization: {status_message}")
                        if debug_info:
                            st.error(f"Debug info: {debug_info.get('error', 'No additional info')}")
                else:
                    st.error("Could not generate attention visualization")
            else:
                st.info("Upload an image and run classification to see AI attention visualization")

    elif uploaded_file is None:
        # Welcome message
        st.markdown("""
## πŸ'‹ Welcome to the Comprehensive Stroke Classification System
This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.
### πŸ"§ New Debugging Features:
- **Step-by-step Grad-CAM debugging** - See exactly where it fails
- **Multiple layer attempts** - Tries different layers automatically
- **Enhanced error messages** - Clear explanations of what went wrong
- **NaN detection** - Identifies computation errors
### 🎯 What to Look For:
- **Green success messages** - Grad-CAM is working
- **Orange warnings** - Using fallback methods
- **Red errors** - Something is broken
- **NaN statistics** - Computation failure
**Upload an image to see detailed debugging! πŸ'ˆ**
""")

    # Medical disclaimer
    st.markdown("---")
    st.warning("⚠️ **Medical Disclaimer:** This AI system is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical decisions.")


if __name__ == "__main__":
    main()