import streamlit as st
import numpy as np
import os
import sys
from PIL import Image

# Environment tweaks: a writable matplotlib config dir and headless server
# mode avoid permission errors in containerized deployments.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

# Optional heavy imports -- the app degrades gracefully when they are missing.
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    st.error("TensorFlow not available")

try:
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend (no display server)
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

# Page config
st.set_page_config(page_title="Stroke Classifier", page_icon="🧠", layout="wide")

# Simple styling
st.markdown(""" """, unsafe_allow_html=True)

# Initialize session state once per browser session.
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
    st.session_state.model = None
    st.session_state.model_status = "Not loaded"

# Class labels in the order of the model's output units.
# NOTE(review): assumed to match the training label order -- confirm against
# the notebook/script that produced stroke_classification_model.h5.
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]


def _preprocess_image(img):
    """Convert a PIL image to a (1, 224, 224, 3) float32 batch scaled to [0, 1].

    Grayscale inputs are replicated across three channels so the model always
    receives an RGB tensor.
    """
    img_resized = img.resize((224, 224))
    img_array = np.array(img_resized, dtype=np.float32)
    if len(img_array.shape) == 2:
        img_array = np.stack([img_array] * 3, axis=-1)
    return np.expand_dims(img_array, axis=0) / 255.0


def find_model_file():
    """Find the model file in various possible locations.

    Returns the first existing candidate path, or None when no .h5 file
    can be located anywhere under the working directory.
    """
    possible_paths = [
        "stroke_classification_model.h5",
        "./stroke_classification_model.h5",
        "/app/stroke_classification_model.h5",
        "src/stroke_classification_model.h5",
        os.path.join(os.getcwd(), "stroke_classification_model.h5")
    ]
    # Also check all .h5 files in current directory and subdirectories
    for root, dirs, files in os.walk('.'):
        for file in files:
            if file.endswith('.h5'):
                possible_paths.append(os.path.join(root, file))
    for path in possible_paths:
        if os.path.exists(path):
            return path
    return None


@st.cache_resource
def load_stroke_model():
    """Load the stroke model with Streamlit resource caching.

    Returns a (model, status_message) pair; model is None on any failure
    and the status message explains why.
    """
    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"
    try:
        model_path = find_model_file()
        if model_path is None:
            # List all files to help debug missing-model deployments.
            current_files = []
            for root, dirs, files in os.walk('.'):
                for file in files:
                    current_files.append(os.path.join(root, file))
            return None, f"❌ Model file not found. Available files: {current_files[:10]}"
        st.info(f"Found model at: {model_path}")
        # compile=False: inference only, training config is not needed.
        model = tf.keras.models.load_model(model_path, compile=False)
        return model, f"✅ Model loaded successfully from: {model_path}"
    except Exception as e:
        return None, f"❌ Model loading failed: {str(e)}"


def analyze_model_architecture(model):
    """Comprehensive analysis of model architecture.

    Categorizes every layer as convolutional / dense / other and infers a
    coarse model type. Returns a dict of lists of per-layer info dicts.
    """
    if model is None:
        return {"error": "No model loaded"}

    layer_analysis = {
        'total_layers': len(model.layers),
        'conv_layers': [],
        'dense_layers': [],
        'other_layers': [],
        'all_layers_detailed': [],
        'model_type': 'Unknown'
    }

    for i, layer in enumerate(model.layers):
        layer_type = type(layer).__name__
        # Get more detailed layer information
        layer_info = {
            'index': i,
            'name': layer.name,
            'type': layer_type,
            'output_shape': getattr(layer, 'output_shape', 'Unknown'),
            'trainable': getattr(layer, 'trainable', 'Unknown'),
            'activation': getattr(layer, 'activation', None)
        }
        # Try to resolve the activation function to a readable name.
        if hasattr(layer, 'activation') and layer.activation:
            try:
                layer_info['activation'] = layer.activation.__name__
            except Exception:  # activation object without a __name__
                layer_info['activation'] = str(layer.activation)

        layer_analysis['all_layers_detailed'].append(layer_info)

        # Categorize layers with comprehensive conv-type detection.
        if any(conv_type in layer_type for conv_type in [
            'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
            'Convolution1D', 'Convolution2D', 'Convolution3D'
        ]) or 'conv' in layer.name.lower():
            layer_analysis['conv_layers'].append(layer_info)
        elif 'Dense' in layer_type or 'Linear' in layer_type:
            layer_analysis['dense_layers'].append(layer_info)
        else:
            layer_analysis['other_layers'].append(layer_info)

    # Determine model type from the dominant layer category.
    if layer_analysis['conv_layers']:
        layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
    elif layer_analysis['dense_layers']:
        layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
    else:
        layer_analysis['model_type'] = 'Custom Architecture'

    return layer_analysis


def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
    """Compute a Grad-CAM heatmap for one layer, recording each step.

    Returns (heatmap ndarray or None, debug_info dict). debug_info['step']
    always names the last step reached, and debug_info['error'] is set on
    any failure so the caller can surface exactly where it broke.
    """
    debug_info = {
        'step': 'Starting',
        'error': None,
        'layer_output_shape': None,
        'gradients_shape': None,
        'gradients_stats': None,
        'heatmap_stats': None
    }

    try:
        debug_info['step'] = 'Getting target layer'
        target_layer = model.get_layer(layer_name)
        debug_info['target_layer_type'] = type(target_layer).__name__

        debug_info['step'] = 'Creating grad model'
        # model.inputs is already a list; do not wrap it in another list.
        grad_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output]
        )

        debug_info['step'] = 'Computing forward pass'
        with tf.GradientTape() as tape:
            layer_output, preds = grad_model(img_array)
            debug_info['layer_output_shape'] = layer_output.shape.as_list()
            debug_info['predictions_shape'] = preds.shape.as_list()
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            debug_info['pred_index'] = int(pred_index)
            debug_info['pred_confidence'] = float(preds[0][pred_index])
            class_channel = preds[:, pred_index]
            debug_info['class_channel_shape'] = class_channel.shape.as_list()

        debug_info['step'] = 'Computing gradients'
        grads = tape.gradient(class_channel, layer_output)
        if grads is None:
            debug_info['error'] = "Gradients are None - no backpropagation path"
            return None, debug_info

        debug_info['gradients_shape'] = grads.shape.as_list()
        debug_info['gradients_stats'] = {
            'min': float(tf.reduce_min(grads)),
            'max': float(tf.reduce_max(grads)),
            'mean': float(tf.reduce_mean(grads)),
            'std': float(tf.math.reduce_std(grads))
        }

        debug_info['step'] = 'Processing gradients based on layer type'
        if len(layer_output.shape) == 4:  # Conv layer: (batch, H, W, C)
            debug_info['processing_type'] = 'Convolutional layer (4D)'
            # Channel importance = gradients averaged over batch and space.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            layer_output = layer_output[0]
            heatmap = layer_output @ pooled_grads[..., tf.newaxis]
            heatmap = tf.squeeze(heatmap)
        elif len(layer_output.shape) == 2:  # Dense layer: (batch, units)
            debug_info['processing_type'] = 'Dense layer (2D)'
            # For dense layers, create spatial heatmap from gradient magnitude
            grads_magnitude = tf.reduce_mean(tf.abs(grads))
            # Create a simple spatial pattern (no true spatial info exists).
            heatmap = tf.ones((14, 14)) * grads_magnitude
        else:
            debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
            return None, debug_info

        debug_info['step'] = 'Normalizing heatmap'
        debug_info['raw_heatmap_stats'] = {
            'min': float(tf.reduce_min(heatmap)),
            'max': float(tf.reduce_max(heatmap)),
            'mean': float(tf.reduce_mean(heatmap)),
            'std': float(tf.math.reduce_std(heatmap))
        }

        # Apply ReLU (remove negative values), then scale to [0, 1].
        heatmap = tf.maximum(heatmap, 0)
        heatmap_max = tf.reduce_max(heatmap)
        if heatmap_max > 0:
            heatmap = heatmap / heatmap_max
        else:
            debug_info['error'] = "All heatmap values are zero or negative"
            return None, debug_info

        debug_info['final_heatmap_stats'] = {
            'min': float(tf.reduce_min(heatmap)),
            'max': float(tf.reduce_max(heatmap)),
            'mean': float(tf.reduce_mean(heatmap)),
            'std': float(tf.math.reduce_std(heatmap))
        }
        debug_info['step'] = 'Complete'
        return heatmap.numpy(), debug_info

    except Exception as e:
        debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
        return None, debug_info


def create_robust_gradcam_heatmap(img, model, predictions):
    """Create a Grad-CAM heatmap, trying candidate layers in preference order.

    Returns a 4-tuple (heatmap or None, status message, stats dict or None,
    debug_info dict or None) on every path.
    """
    try:
        img_array = _preprocess_image(img)

        # Rank candidate layers: conv layers first, then other 4D layers,
        # and dense layers only as an experimental last resort.
        analysis = analyze_model_architecture(model)
        layer_candidates = []
        for layer in analysis['conv_layers']:
            layer_candidates.append((layer['name'], f"Conv layer: {layer['name']}"))
        for layer in analysis['all_layers_detailed']:
            if (layer['type'] in ['Activation', 'BatchNormalization'] and
                    isinstance(layer['output_shape'], (list, tuple)) and
                    len(layer['output_shape']) == 4):
                layer_candidates.append(
                    (layer['name'], f"4D layer: {layer['name']} ({layer['type']})"))
        if not layer_candidates:
            for layer in analysis['dense_layers']:
                layer_candidates.append(
                    (layer['name'], f"Dense layer: {layer['name']} (experimental)"))

        if not layer_candidates:
            return None, "❌ No suitable layers found", None, None

        debug_info = None
        for layer_name, layer_desc in layer_candidates:
            pred_index = np.argmax(predictions)
            heatmap, debug_info = debug_gradcam_step_by_step(
                img_array, model, layer_name, pred_index
            )
            if heatmap is None:
                continue  # this layer failed -- try the next candidate

            # Resize heatmap to match the 224x224 input image.
            if heatmap.shape[0] != 224 or heatmap.shape[1] != 224:
                heatmap_resized = tf.image.resize(
                    heatmap[..., tf.newaxis], (224, 224)
                ).numpy()[:, :, 0]
            else:
                heatmap_resized = heatmap

            stats = {
                'min': float(np.min(heatmap_resized)),
                'max': float(np.max(heatmap_resized)),
                'mean': float(np.mean(heatmap_resized)),
                'std': float(np.std(heatmap_resized))
            }
            return (heatmap_resized,
                    f"✅ Grad-CAM successful using {layer_desc}",
                    stats, debug_info)

        # All candidates failed; surface the last attempt's error.
        return (None,
                f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}",
                None, debug_info)

    except Exception as e:
        return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)}


def predict_stroke(img, model):
    """Predict stroke type from image.

    Returns (probability vector, None) on success or (None, error message).
    """
    if model is None:
        return None, "Model not loaded"
    try:
        img_array = _preprocess_image(img)
        predictions = model.predict(img_array, verbose=0)
        return predictions[0], None
    except Exception as e:
        return None, f"Prediction error: {str(e)}"


def create_enhanced_simulated_heatmap(img, predictions):
    """Create a more realistic simulated heatmap (fallback when Grad-CAM fails).

    Builds class-dependent Gaussian blobs scaled by prediction confidence.
    Returns (heatmap, status message, stats) or (None, error message, None).
    """
    try:
        confidence = np.max(predictions)
        predicted_class = np.argmax(predictions)

        # Different focal points per predicted class (purely illustrative).
        if predicted_class == 0:  # Hemorrhagic: center-left region
            center_x, center_y = 80, 112
        elif predicted_class == 1:  # Ischemic: right side
            center_x, center_y = 150, 112
        else:  # No stroke: diffuse, centered pattern
            center_x, center_y = 112, 112

        y, x = np.ogrid[:224, :224]
        # Primary focus area plus two secondary blobs.
        mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))
        mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2)))
        mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2)))
        heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence

        # Seeded local RNG: deterministic noise without mutating the
        # global numpy random state.
        noise = np.random.RandomState(42).normal(0, 0.05, heatmap.shape)
        heatmap = np.maximum(heatmap + noise, 0)
        if np.max(heatmap) > 0:
            heatmap = heatmap / np.max(heatmap)

        stats = {
            'min': float(np.min(heatmap)),
            'max': float(np.max(heatmap)),
            'mean': float(np.mean(heatmap)),
            'std': float(np.std(heatmap))
        }
        return heatmap, "⚠️ Using enhanced simulated heatmap", stats
    except Exception as e:
        return None, f"❌ Simulated heatmap error: {str(e)}", None


def create_comprehensive_visualization(img, predictions, model,
                                       force_gradcam=True, colormap='hot'):
    """Create the 3-panel figure: original, heatmap, and overlay.

    Tries real Grad-CAM first, falls back to the simulated heatmap.
    Returns (figure or None, status message, stats or None, debug_info or None).
    """
    if not MPL_AVAILABLE:
        return None, "❌ Matplotlib not available", None, None
    try:
        img_resized = img.resize((224, 224))
        img_array = np.array(img_resized)

        heatmap = None
        status_message = ""
        stats = None
        debug_info = None

        # Try Grad-CAM first
        if force_gradcam and model is not None:
            result = create_robust_gradcam_heatmap(img, model, predictions)
            if result and len(result) >= 3:
                heatmap, gradcam_status, stats = result[0], result[1], result[2]
                if len(result) > 3:
                    debug_info = result[3]
                status_message = gradcam_status

        # Fallback to enhanced simulated if Grad-CAM failed
        if heatmap is None:
            result = create_enhanced_simulated_heatmap(img, predictions)
            if result and len(result) == 3:
                heatmap, sim_status, stats = result
                if status_message:
                    status_message += f" | {sim_status}"
                else:
                    status_message = sim_status

        if heatmap is None:
            return None, "❌ Could not generate any heatmap", None, None

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # 1. Original image
        axes[0].imshow(img_array)
        axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
        axes[0].axis('off')

        # 2. Heatmap only
        im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
        axes[1].set_title(f"Attention Heatmap ({colormap})",
                          fontsize=12, fontweight='bold')
        axes[1].axis('off')
        plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

        # 3. Overlay
        axes[2].imshow(img_array)
        im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6,
                             vmin=0, vmax=1, interpolation='bilinear')
        # Title color signals whether the attention map is real or simulated.
        if "✅ Grad-CAM successful" in status_message:
            title = "🎯 Real AI Attention Overlay"
            title_color = 'green'
        else:
            title = "🎨 Simulated Attention Overlay"
            title_color = 'orange'
        axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
        axes[2].axis('off')
        plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

        plt.tight_layout()
        return fig, status_message, stats, debug_info

    except Exception as e:
        return None, f"❌ Visualization error: {str(e)}", None, None


# Main App
def main():
    """Render the full Streamlit page: status, sidebar, results, and overlay."""
    # Header
    st.markdown('🧠 AI-Powered Stroke Classification System',
                unsafe_allow_html=True)

    # Auto-load model on startup
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    # System status
    st.markdown("### 🔧 System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if TF_AVAILABLE:
            st.markdown('✅ TensorFlow Ready', unsafe_allow_html=True)
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.markdown('❌ TensorFlow Error', unsafe_allow_html=True)
    with col2:
        if MPL_AVAILABLE:
            st.markdown('✅ Matplotlib Ready', unsafe_allow_html=True)
        else:
            st.markdown('❌ Matplotlib Error', unsafe_allow_html=True)
    with col3:
        if "✅" in st.session_state.model_status:
            st.markdown('✅ Model Loaded', unsafe_allow_html=True)
        else:
            st.markdown('❌ Model Error', unsafe_allow_html=True)

    # Model status details
    st.markdown(f'Model Status: {st.session_state.model_status}',
                unsafe_allow_html=True)

    # Enhanced model architecture analysis
    if st.session_state.model is not None:
        with st.expander("🔍 Detailed Model Architecture Analysis"):
            analysis = analyze_model_architecture(st.session_state.model)
            st.write("**📊 Model Summary:**")
            st.write(f"- **Model Type:** {analysis['model_type']}")
            st.write(f"- **Total Layers:** {analysis['total_layers']}")
            st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
            st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
            st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
            st.write("**🔍 All Layers (Detailed):**")
            for layer in analysis['all_layers_detailed']:
                activation_info = (f" | Activation: {layer['activation']}"
                                   if layer['activation'] else "")
                st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | "
                        f"Shape: {layer['output_shape']}{activation_info}")

    # Manual reload button
    if st.button("🔄 Reload Model", help="Try to reload the model"):
        st.session_state.model_loaded = False
        st.rerun()

    # Sidebar
    with st.sidebar:
        st.header("📤 Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification"
        )
        st.markdown("---")
        st.header("🎨 Visualization Options")
        force_gradcam = st.checkbox(
            "Attempt Grad-CAM",
            value=True,
            help="Try Grad-CAM with comprehensive debugging"
        )
        colormap = st.selectbox(
            "Color Scheme",
            ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
            index=0,
            help="Choose color scheme for heatmap visualization"
        )
        show_probabilities = st.checkbox("Show All Probabilities", value=True)
        show_debug = st.checkbox("Show Debug Info", value=True)
        show_stats = st.checkbox("Show Heatmap Statistics", value=True)
        show_detailed_debug = st.checkbox("Show Detailed Debug Info", value=False)

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        # Explicit init so col2 can safely test for a successful prediction.
        predictions = None

        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("📋 Classification Results")
            if st.session_state.model is not None:
                with st.spinner("🔍 Analyzing brain scan..."):
                    predictions, error = predict_stroke(image, st.session_state.model)
                if error:
                    st.error(error)
                else:
                    class_idx = np.argmax(predictions)
                    confidence = predictions[class_idx] * 100
                    predicted_class = STROKE_LABELS[class_idx]
                    st.markdown(f"""
{predicted_class}

Confidence: {confidence:.1f}%
""", unsafe_allow_html=True)
                    if show_probabilities:
                        st.write("**📊 All Probabilities:**")
                        for i, (label, prob) in enumerate(zip(STROKE_LABELS, predictions)):
                            st.write(f"• {label}: {prob*100:.1f}%")
            else:
                st.error("❌ Model not loaded. Check the debug information above "
                         "to see available files.")

        with col2:
            st.subheader("🎯 Comprehensive AI Attention Visualization")
            if st.session_state.model is not None and predictions is not None:
                with st.spinner("🎨 Generating comprehensive attention visualization..."):
                    result = create_comprehensive_visualization(
                        image, predictions, st.session_state.model,
                        force_gradcam, colormap
                    )
                if result and len(result) >= 2:
                    overlay_fig, status_message = result[0], result[1]
                    stats = result[2] if len(result) > 2 else None
                    debug_info = result[3] if len(result) > 3 else None

                    if overlay_fig is not None:
                        st.pyplot(overlay_fig)
                        plt.close(overlay_fig)  # free the figure's memory

                        # Show detailed status
                        if show_debug:
                            if "✅ Grad-CAM successful" in status_message:
                                st.success(f"✅ {status_message}")
                            elif "⚠️" in status_message:
                                st.warning(f"⚠️ {status_message}")
                            else:
                                st.error(f"❌ {status_message}")

                        # Show heatmap statistics
                        if show_stats and stats:
                            st.write("**📈 Heatmap Statistics:**")
                            if any(np.isnan([stats['min'], stats['max'],
                                             stats['mean'], stats['std']])):
                                st.error("⚠️ NaN values detected in heatmap - "
                                         "this indicates a computation error")
                            else:
                                col_stats1, col_stats2 = st.columns(2)
                                with col_stats1:
                                    st.write(f"• Min: {stats['min']:.3f}")
                                    st.write(f"• Max: {stats['max']:.3f}")
                                with col_stats2:
                                    st.write(f"• Mean: {stats['mean']:.3f}")
                                    st.write(f"• Std: {stats['std']:.3f}")

                        # Show detailed debug information
                        if show_detailed_debug and debug_info:
                            with st.expander("🔧 Detailed Debug Information"):
                                st.json(debug_info)
                    else:
                        st.error(f"Could not generate visualization: {status_message}")
                        if debug_info:
                            st.error(f"Debug info: "
                                     f"{debug_info.get('error', 'No additional info')}")
                else:
                    st.error("Could not generate attention visualization")
            else:
                st.info("Upload an image and run classification to see "
                        "AI attention visualization")
    else:
        # Welcome message
        st.markdown("""
## 👋 Welcome to the Comprehensive Stroke Classification System

This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.

### 🔧 New Debugging Features:
- **Step-by-step Grad-CAM debugging** - See exactly where it fails
- **Multiple layer attempts** - Tries different layers automatically
- **Enhanced error messages** - Clear explanations of what went wrong
- **NaN detection** - Identifies computation errors

### 🎯 What to Look For:
- **Green success messages** - Grad-CAM is working
- **Orange warnings** - Using fallback methods
- **Red errors** - Something is broken
- **NaN statistics** - Computation failure

**Upload an image to see detailed debugging! 👈**
""")

    # Medical disclaimer
    st.markdown("---")
    st.warning("⚠️ **Medical Disclaimer:** This AI system is for educational and "
               "research purposes only. It should not be used for actual medical "
               "diagnosis. Always consult qualified healthcare professionals for "
               "medical decisions.")


if __name__ == "__main__":
    main()