import streamlit as st import numpy as np import os import sys from PIL import Image # Set environment variables to fix permission issues os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true' # Minimal imports to avoid conflicts try: import tensorflow as tf TF_AVAILABLE = True except ImportError: TF_AVAILABLE = False st.error("TensorFlow not available") try: import matplotlib matplotlib.use('Agg') # Use non-interactive backend import matplotlib.pyplot as plt import matplotlib.cm as cm MPL_AVAILABLE = True except ImportError: MPL_AVAILABLE = False # Page config st.set_page_config( page_title="Stroke Classifier", page_icon="🧠", layout="wide") # Simple styling st.markdown(""" """, unsafe_allow_html=True) # Initialize session state if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False st.session_state.model = None st.session_state.model_status = "Not loaded" STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"] def find_model_file(): """Find the model file in various possible locations.""" possible_paths = [ "stroke_classification_model.h5", "./stroke_classification_model.h5", "/app/stroke_classification_model.h5", "src/stroke_classification_model.h5", os.path.join(os.getcwd(), "stroke_classification_model.h5") ] # Also check all .h5 files in current directory and subdirectories for root, dirs, files in os.walk('.'): for file in files: if file.endswith('.h5'): possible_paths.append(os.path.join(root, file)) for path in possible_paths: if os.path.exists(path): return path return None @st.cache_resource def load_stroke_model(): """Load model with caching.""" if not TF_AVAILABLE: return None, "❌ TensorFlow not available" try: # Find the model file model_path = find_model_file() if model_path is None: # List all files to help debug current_files = [] for root, dirs, files in os.walk('.'): for file in files: current_files.append(os.path.join(root, file)) return None, f"❌ Model file not found. Available files: {current_files[:10]}" st.info(f"Found model at: {model_path}") # Load model with minimal custom objects model = tf.keras.models.load_model(model_path, compile=False) return model, f"✅ Model loaded successfully from: {model_path}" except Exception as e: return None, f"❌ Model loading failed: {str(e)}" def analyze_model_architecture(model): """Comprehensive analysis of model architecture.""" if model is None: return {"error": "No model loaded"} layer_analysis = { 'total_layers': len(model.layers), 'conv_layers': [], 'dense_layers': [], 'other_layers': [], 'all_layers_detailed': [], 'model_type': 'Unknown' } for i, layer in enumerate(model.layers): layer_type = type(layer).__name__ # Get more detailed layer information layer_info = { 'index': i, 'name': layer.name, 'type': layer_type, 'output_shape': getattr(layer, 'output_shape', 'Unknown'), 'trainable': getattr(layer, 'trainable', 'Unknown'), 'activation': getattr(layer, 'activation', None) } # Try to get activation function name if hasattr(layer, 'activation') and layer.activation: try: layer_info['activation'] = layer.activation.__name__ except: layer_info['activation'] = str(layer.activation) layer_analysis['all_layers_detailed'].append(layer_info) # Categorize layers with more comprehensive detection if any(conv_type in layer_type for conv_type in [ 'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D', 'Convolution1D', 'Convolution2D', 'Convolution3D' ]) or 'conv' in layer.name.lower(): layer_analysis['conv_layers'].append(layer_info) elif 'Dense' in layer_type or 'Linear' in layer_type: layer_analysis['dense_layers'].append(layer_info) else: layer_analysis['other_layers'].append(layer_info) # Determine model type if layer_analysis['conv_layers']: layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)' elif layer_analysis['dense_layers']: layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)' else: layer_analysis['model_type'] = 'Custom Architecture' return layer_analysis def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index): """Debug Grad-CAM computation step by step.""" debug_info = { 'step': 'Starting', 'error': None, 'layer_output_shape': None, 'gradients_shape': None, 'gradients_stats': None, 'heatmap_stats': None } try: debug_info['step'] = 'Getting target layer' target_layer = model.get_layer(layer_name) debug_info['target_layer_type'] = type(target_layer).__name__ debug_info['step'] = 'Creating grad model' grad_model = tf.keras.Model( inputs=[model.inputs], outputs=[target_layer.output, model.output] ) debug_info['step'] = 'Computing forward pass' with tf.GradientTape() as tape: layer_output, preds = grad_model(img_array) debug_info['layer_output_shape'] = layer_output.shape.as_list() debug_info['predictions_shape'] = preds.shape.as_list() if pred_index is None: pred_index = tf.argmax(preds[0]) debug_info['pred_index'] = int(pred_index) debug_info['pred_confidence'] = float(preds[0][pred_index]) class_channel = preds[:, pred_index] debug_info['class_channel_shape'] = class_channel.shape.as_list() debug_info['step'] = 'Computing gradients' grads = tape.gradient(class_channel, layer_output) if grads is None: debug_info['error'] = "Gradients are None - no backpropagation path" return None, debug_info debug_info['gradients_shape'] = grads.shape.as_list() debug_info['gradients_stats'] = { 'min': float(tf.reduce_min(grads)), 'max': float(tf.reduce_max(grads)), 'mean': float(tf.reduce_mean(grads)), 'std': float(tf.math.reduce_std(grads)) } debug_info['step'] = 'Processing gradients based on layer type' if len(layer_output.shape) == 4: # Conv layer debug_info['processing_type'] = 'Convolutional layer (4D)' pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)) layer_output = layer_output[0] heatmap = layer_output @ pooled_grads[..., tf.newaxis] heatmap = tf.squeeze(heatmap) elif len(layer_output.shape) == 2: # Dense layer debug_info['processing_type'] = 'Dense layer (2D)' # For dense layers, create spatial heatmap from gradient magnitude grads_magnitude = tf.reduce_mean(tf.abs(grads)) # Create a simple spatial pattern heatmap = tf.ones((14, 14)) * grads_magnitude else: debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}" return None, debug_info debug_info['step'] = 'Normalizing heatmap' debug_info['raw_heatmap_stats'] = { 'min': float(tf.reduce_min(heatmap)), 'max': float(tf.reduce_max(heatmap)), 'mean': float(tf.reduce_mean(heatmap)), 'std': float(tf.math.reduce_std(heatmap)) } # Apply ReLU (remove negative values) heatmap = tf.maximum(heatmap, 0) # Normalize heatmap_max = tf.reduce_max(heatmap) if heatmap_max > 0: heatmap = heatmap / heatmap_max else: debug_info['error'] = "All heatmap values are zero or negative" return None, debug_info debug_info['final_heatmap_stats'] = { 'min': float(tf.reduce_min(heatmap)), 'max': float(tf.reduce_max(heatmap)), 'mean': float(tf.reduce_mean(heatmap)), 'std': float(tf.math.reduce_std(heatmap)) } debug_info['step'] = 'Complete' return heatmap.numpy(), debug_info except Exception as e: debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}" return None, debug_info def create_robust_gradcam_heatmap(img, model, predictions): """Create Grad-CAM with comprehensive debugging.""" try: # Preprocess image img_resized = img.resize((224, 224)) img_array = np.array(img_resized, dtype=np.float32) # Handle grayscale if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize and add batch dimension img_array = np.expand_dims(img_array, axis=0) / 255.0 # Get model analysis analysis = analyze_model_architecture(model) # Try different layers in order of preference layer_candidates = [] # Add conv layers first for layer in analysis['conv_layers']: layer_candidates.append((layer['name'], f"Conv layer: {layer['name']}")) # Add other potentially suitable layers for layer in analysis['all_layers_detailed']: if (layer['type'] in ['Activation', 'BatchNormalization'] and isinstance(layer['output_shape'], (list, tuple)) and len(layer['output_shape']) == 4): layer_candidates.append((layer['name'], f"4D layer: {layer['name']} ({layer['type']})")) # Try dense layers as last resort if not layer_candidates: for layer in analysis['dense_layers']: layer_candidates.append((layer['name'], f"Dense layer: {layer['name']} (experimental)")) if not layer_candidates: return None, "❌ No suitable layers found", None # Try each candidate layer for layer_name, layer_desc in layer_candidates: pred_index = np.argmax(predictions) heatmap, debug_info = debug_gradcam_step_by_step( img_array, model, layer_name, pred_index ) if heatmap is not None: # Resize heatmap to match input image size if heatmap.shape[0] != 224 or heatmap.shape[1] != 224: heatmap_resized = tf.image.resize( heatmap[..., tf.newaxis], (224, 224) ).numpy()[:, :, 0] else: heatmap_resized = heatmap # Final statistics stats = { 'min': float(np.min(heatmap_resized)), 'max': float(np.max(heatmap_resized)), 'mean': float(np.mean(heatmap_resized)), 'std': float(np.std(heatmap_resized)) } return heatmap_resized, f"✅ Grad-CAM successful using {layer_desc}", stats, debug_info else: # Continue to next layer if this one failed continue # If all layers failed, return debug info from the last attempt return None, f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info except Exception as e: return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)} def predict_stroke(img, model): """Predict stroke type from image.""" if model is None: return None, "Model not loaded" try: # Preprocess image img_resized = img.resize((224, 224)) img_array = np.array(img_resized, dtype=np.float32) # Handle grayscale if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize and add batch dimension img_array = np.expand_dims(img_array, axis=0) / 255.0 # Predict predictions = model.predict(img_array, verbose=0) return predictions[0], None except Exception as e: return None, f"Prediction error: {str(e)}" def create_enhanced_simulated_heatmap(img, predictions): """Create a more realistic simulated heatmap.""" try: confidence = np.max(predictions) predicted_class = np.argmax(predictions) # Create different patterns based on predicted class if predicted_class == 0: # Hemorrhagic # Focus on center-left region center_x, center_y = 80, 112 elif predicted_class == 1: # Ischemic # Focus on right side center_x, center_y = 150, 112 else: # No stroke # Diffuse, low-intensity pattern center_x, center_y = 112, 112 # Create base pattern y, x = np.ogrid[:224, :224] # Primary focus area mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2))) # Secondary areas mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2))) mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2))) # Combine patterns heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence # Add some noise for realism np.random.seed(42) noise = np.random.normal(0, 0.05, heatmap.shape) heatmap = np.maximum(heatmap + noise, 0) # Normalize if np.max(heatmap) > 0: heatmap = heatmap / np.max(heatmap) stats = { 'min': float(np.min(heatmap)), 'max': float(np.max(heatmap)), 'mean': float(np.mean(heatmap)), 'std': float(np.std(heatmap)) } return heatmap, "⚠️ Using enhanced simulated heatmap", stats except Exception as e: return None, f"❌ Simulated heatmap error: {str(e)}", None def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'): """Create comprehensive visualization with debugging.""" if not MPL_AVAILABLE: return None, "❌ Matplotlib not available" try: # Resize image to 224x224 img_resized = img.resize((224, 224)) img_array = np.array(img_resized) heatmap = None status_message = "" stats = None debug_info = None # Try Grad-CAM first if force_gradcam and model is not None: result = create_robust_gradcam_heatmap(img, model, predictions) if result and len(result) >= 3: heatmap, gradcam_status, stats = result[0], result[1], result[2] if len(result) > 3: debug_info = result[3] status_message = gradcam_status # Fallback to enhanced simulated if Grad-CAM failed if heatmap is None: result = create_enhanced_simulated_heatmap(img, predictions) if result and len(result) == 3: heatmap, sim_status, stats = result if status_message: status_message += f" | {sim_status}" else: status_message = sim_status if heatmap is None: return None, "❌ Could not generate any heatmap", None, None # Create visualization fig, axes = plt.subplots(1, 3, figsize=(15, 5)) # 1. Original image axes[0].imshow(img_array) axes[0].set_title("Original Image", fontsize=12, fontweight='bold') axes[0].axis('off') # 2. Heatmap only im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1) axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold') axes[1].axis('off') plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04) # 3. Overlay axes[2].imshow(img_array) im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear') # Determine title based on success if "✅ Grad-CAM successful" in status_message: title = "🎯 Real AI Attention Overlay" title_color = 'green' else: title = "🎨 Simulated Attention Overlay" title_color = 'orange' axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color) axes[2].axis('off') plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04) plt.tight_layout() return fig, status_message, stats, debug_info except Exception as e: return None, f"❌ Visualization error: {str(e)}", None, None # Main App def main(): # Header st.markdown('