Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import os | |
| import sys | |
| from PIL import Image | |
# Set environment variables to fix permission issues: give matplotlib a
# writable config/cache directory and mark the Streamlit server headless.
os.environ.update({
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'STREAMLIT_SERVER_HEADLESS': 'true',
})
# Optional heavy dependencies. Each import is guarded so the app can still
# start and report what is missing instead of failing on load.
#
# NOTE: no Streamlit call is made here. Streamlit has historically required
# `st.set_page_config` (called just below) to be the first st.* command of the
# script, so an `st.error` here would itself raise when TensorFlow is missing.
# The missing-TensorFlow case is reported later by main()'s System Status panel
# ("TensorFlow Error") and by load_stroke_model()'s status message.
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False

try:
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend (no display on the server)
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False
# Page config: browser-tab title/icon and a wide layout.
_PAGE_SETTINGS = dict(
    layout="wide",
    page_title="Stroke Classifier",
    page_icon="π§ ",
)
st.set_page_config(**_PAGE_SETTINGS)

# Simple styling for the header, the prediction card and the coloured
# status-box variants referenced by the HTML snippets in main().
_APP_CSS = """
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.prediction-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 1rem;
text-align: center;
margin: 1rem 0;
}
.status-box {
padding: 1rem;
border-radius: 0.5rem;
margin: 1rem 0;
}
.success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
.error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
.info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
.warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
.debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
</style>"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Initialize session state once per browser session; main() fills in the
# model on its first run and the "Reload Model" button resets `model_loaded`.
if 'model_loaded' not in st.session_state:
    st.session_state.update(
        model_loaded=False,
        model=None,
        model_status="Not loaded",
    )

# Class names, indexed by the position of the model's output (main() looks up
# STROKE_LABELS[argmax(predictions)]) — assumed to match the training order.
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]
def find_model_file():
    """Find the model file in various possible locations.

    Well-known locations of ``stroke_classification_model.h5`` are checked
    first, in priority order. Only if none of them exists is the current
    directory tree walked, returning the first ``*.h5`` file encountered (in
    ``os.walk`` order, so with several candidates the choice depends on the
    filesystem's listing order).

    Returns:
        The path of the first existing candidate, or ``None`` if nothing was
        found.
    """
    preferred_paths = (
        "stroke_classification_model.h5",
        "./stroke_classification_model.h5",
        "/app/stroke_classification_model.h5",
        "src/stroke_classification_model.h5",
        os.path.join(os.getcwd(), "stroke_classification_model.h5"),
    )
    for path in preferred_paths:
        if os.path.exists(path):
            return path

    # Fall back to any .h5 file below the working directory. The walk is lazy
    # and stops at the first match instead of scanning the whole tree up front
    # on every call (even when a preferred path exists).
    for root, _dirs, files in os.walk('.'):
        for name in files:
            if name.endswith('.h5'):
                return os.path.join(root, name)
    return None
def load_stroke_model():
    """Locate and load the stroke classification model.

    No caching happens here: main() stores the result in ``st.session_state``,
    so this runs once per session unless the "Reload Model" button is used.

    Returns:
        ``(model, status_message)``: the loaded Keras model (or ``None`` on any
        failure) and a human-readable status string shown in the UI.
    """
    from itertools import islice

    if not TF_AVAILABLE:
        return None, "β TensorFlow not available"
    try:
        # Find the model file
        model_path = find_model_file()
        if model_path is None:
            # List a sample of the files that *are* present to help debugging.
            # Only 10 are displayed, so stop walking once 10 have been seen
            # instead of enumerating the entire directory tree.
            all_files = (
                os.path.join(root, name)
                for root, _dirs, files in os.walk('.')
                for name in files
            )
            current_files = list(islice(all_files, 10))
            return None, f"β Model file not found. Available files: {current_files}"
        st.info(f"Found model at: {model_path}")
        # compile=False: the model is only used for inference, so the training
        # configuration (optimizer/loss) does not need to be restored.
        model = tf.keras.models.load_model(model_path, compile=False)
        return model, f"β Model loaded successfully from: {model_path}"
    except Exception as e:
        # UI boundary: report any loading failure as a status message.
        return None, f"β Model loading failed: {str(e)}"
def _describe_output_shape(layer):
    """Return a layer's output shape, or 'Unknown' if it cannot be determined.

    Keras 2 layers expose ``output_shape``; Keras 3 removed that attribute, so
    fall back to the shape of the layer's symbolic ``output`` tensor.
    """
    try:
        return layer.output_shape
    except AttributeError:
        pass
    try:
        return tuple(layer.output.shape)
    except Exception:  # unbuilt layer, several inbound nodes, no output, ...
        return 'Unknown'


def _activation_name(layer):
    """Return the name of ``layer.activation``, its str(), or the falsy value itself."""
    activation = getattr(layer, 'activation', None)
    if not activation:
        return activation
    try:
        return activation.__name__
    except Exception:  # e.g. a callable object without __name__
        return str(activation)


def analyze_model_architecture(model):
    """Comprehensive analysis of model architecture.

    Args:
        model: A loaded Keras model (anything exposing a ``layers`` sequence),
            or ``None``.

    Returns:
        ``{"error": "No model loaded"}`` when ``model`` is ``None``. Otherwise a
        dict with ``total_layers``; ``conv_layers``, ``dense_layers``,
        ``other_layers`` and ``all_layers_detailed`` (lists of per-layer dicts
        with ``index``, ``name``, ``type``, ``output_shape``, ``trainable`` and
        ``activation``); and a coarse ``model_type`` label.
    """
    if model is None:
        return {"error": "No model loaded"}

    conv_type_markers = (
        'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
        'Convolution1D', 'Convolution2D', 'Convolution3D',
    )
    layer_analysis = {
        'total_layers': len(model.layers),
        'conv_layers': [],
        'dense_layers': [],
        'other_layers': [],
        'all_layers_detailed': [],
        'model_type': 'Unknown',
    }

    for i, layer in enumerate(model.layers):
        layer_type = type(layer).__name__
        layer_info = {
            'index': i,
            'name': layer.name,
            'type': layer_type,
            'output_shape': _describe_output_shape(layer),
            'trainable': getattr(layer, 'trainable', 'Unknown'),
            'activation': _activation_name(layer),
        }
        layer_analysis['all_layers_detailed'].append(layer_info)

        # The name test is a heuristic: any layer whose name contains "conv"
        # (e.g. a BatchNormalization named "conv1_bn") is also treated as a
        # conv layer, which widens the Grad-CAM candidate set.
        is_conv = (any(marker in layer_type for marker in conv_type_markers)
                   or 'conv' in layer.name.lower())
        if is_conv:
            layer_analysis['conv_layers'].append(layer_info)
        elif 'Dense' in layer_type or 'Linear' in layer_type:
            layer_analysis['dense_layers'].append(layer_info)
        else:
            layer_analysis['other_layers'].append(layer_info)

    if layer_analysis['conv_layers']:
        layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
    elif layer_analysis['dense_layers']:
        layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
    else:
        layer_analysis['model_type'] = 'Custom Architecture'
    return layer_analysis
def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
    """Compute a Grad-CAM heatmap for one layer, recording diagnostics step by step.

    Args:
        img_array: Input batch accepted by ``model`` (the caller in this file
            passes a ``(1, 224, 224, 3)`` float array scaled to [0, 1]).
        model: Keras model to explain.
        layer_name: Name of the layer whose activations are weighted.
        pred_index: Class index to explain; ``None`` means the top prediction.

    Returns:
        ``(heatmap, debug_info)``. ``heatmap`` is a 2-D numpy array scaled to
        [0, 1], or ``None`` if any step failed. ``debug_info`` records the last
        step reached, an ``error`` message on failure, and shapes/statistics
        of intermediate tensors (plain Python values, JSON-serialisable).
    """
    def tensor_stats(t):
        """min/max/mean/std of a tensor as plain floats."""
        return {
            'min': float(tf.reduce_min(t)),
            'max': float(tf.reduce_max(t)),
            'mean': float(tf.reduce_mean(t)),
            'std': float(tf.math.reduce_std(t)),
        }

    debug_info = {
        'step': 'Starting',
        'error': None,
        'layer_output_shape': None,
        'gradients_shape': None,
        'gradients_stats': None,
        'heatmap_stats': None,
    }
    try:
        debug_info['step'] = 'Getting target layer'
        target_layer = model.get_layer(layer_name)
        debug_info['target_layer_type'] = type(target_layer).__name__

        debug_info['step'] = 'Creating grad model'
        # model.inputs is already a list; wrapping it again ([model.inputs])
        # creates a nested input structure that Keras 3 warns about on call.
        grad_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output],
        )

        debug_info['step'] = 'Computing forward pass'
        with tf.GradientTape() as tape:
            layer_output, preds = grad_model(img_array)
            debug_info['layer_output_shape'] = layer_output.shape.as_list()
            debug_info['predictions_shape'] = preds.shape.as_list()
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            debug_info['pred_index'] = int(pred_index)
            debug_info['pred_confidence'] = float(preds[0][pred_index])
            # The class score must be sliced inside the tape so the gradient
            # w.r.t. the layer output can be traced back through it.
            class_channel = preds[:, pred_index]
            debug_info['class_channel_shape'] = class_channel.shape.as_list()

        debug_info['step'] = 'Computing gradients'
        grads = tape.gradient(class_channel, layer_output)
        if grads is None:
            debug_info['error'] = "Gradients are None - no backpropagation path"
            return None, debug_info
        debug_info['gradients_shape'] = grads.shape.as_list()
        debug_info['gradients_stats'] = tensor_stats(grads)

        debug_info['step'] = 'Processing gradients based on layer type'
        rank = len(layer_output.shape)
        if rank == 4:  # Conv-like layer: (batch, height, width, channels)
            debug_info['processing_type'] = 'Convolutional layer (4D)'
            # Channel weights = gradients averaged over batch and space.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            feature_maps = layer_output[0]
            heatmap = feature_maps @ pooled_grads[..., tf.newaxis]
            # Drop only the trailing channel axis so a 1xN, Nx1 or 1x1 map
            # stays 2-D (a bare squeeze would collapse it further).
            heatmap = tf.squeeze(heatmap, axis=-1)
        elif rank == 2:  # Dense layer: (batch, units)
            debug_info['processing_type'] = 'Dense layer (2D)'
            # Dense activations have no spatial layout, so this is a *uniform*
            # 14x14 map scaled by the mean gradient magnitude; after
            # normalisation it is constant and carries no localisation info.
            grads_magnitude = tf.reduce_mean(tf.abs(grads))
            heatmap = tf.ones((14, 14)) * grads_magnitude
        else:
            debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
            return None, debug_info

        debug_info['step'] = 'Normalizing heatmap'
        debug_info['raw_heatmap_stats'] = tensor_stats(heatmap)
        # ReLU: keep only regions with positive evidence for the class.
        heatmap = tf.maximum(heatmap, 0)
        heatmap_max = tf.reduce_max(heatmap)
        if heatmap_max > 0:
            heatmap = heatmap / heatmap_max
        else:
            debug_info['error'] = "All heatmap values are zero or negative"
            return None, debug_info
        debug_info['final_heatmap_stats'] = tensor_stats(heatmap)

        debug_info['step'] = 'Complete'
        return heatmap.numpy(), debug_info
    except Exception as e:
        debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
        return None, debug_info
def create_robust_gradcam_heatmap(img, model, predictions):
    """Create a Grad-CAM heatmap, trying candidate layers until one succeeds.

    Candidate order:
      1. convolutional layers (as classified by analyze_model_architecture),
         deepest first;
      2. other ``Activation``/``BatchNormalization`` layers with a 4-D output
         shape, deepest first;
      3. only if neither exists, dense layers in model order (experimental).

    Duplicates are skipped. Standard Grad-CAM targets the *last* convolutional
    layer, whose features are the most class-specific, hence deepest-first.

    Args:
        img: PIL image to explain.
        model: Keras model that produced ``predictions``.
        predictions: Class-probability vector for ``img``; its argmax is the
            class explained.

    Returns:
        ``(heatmap, status, stats, debug_info)``: a 224x224 numpy heatmap in
        [0, 1] (or ``None``), a status message, min/max/mean/std of the heatmap
        (or ``None``), and the diagnostics of the last layer attempted.
    """
    try:
        # Same preprocessing as predict_stroke: RGB, 224x224, scaled to [0, 1].
        # convert('RGB') also handles grayscale, palette and RGBA uploads, which
        # would otherwise yield a 2-D, index-valued or 4-channel array.
        img_resized = img.convert('RGB').resize((224, 224))
        img_array = np.array(img_resized, dtype=np.float32)
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        analysis = analyze_model_architecture(model)

        layer_candidates = []
        seen_names = set()

        def add_candidate(name, description):
            if name not in seen_names:
                seen_names.add(name)
                layer_candidates.append((name, description))

        for layer in reversed(analysis['conv_layers']):
            add_candidate(layer['name'], f"Conv layer: {layer['name']}")
        for layer in reversed(analysis['all_layers_detailed']):
            if (layer['type'] in ['Activation', 'BatchNormalization'] and
                    isinstance(layer['output_shape'], (list, tuple)) and
                    len(layer['output_shape']) == 4):
                add_candidate(layer['name'], f"4D layer: {layer['name']} ({layer['type']})")
        # Dense layers only as a last resort (they give a uniform map).
        if not layer_candidates:
            for layer in analysis['dense_layers']:
                add_candidate(layer['name'], f"Dense layer: {layer['name']} (experimental)")

        if not layer_candidates:
            return None, "β No suitable layers found", None, {'error': 'No suitable layers found'}

        pred_index = int(np.argmax(predictions))
        debug_info = {}
        for layer_name, layer_desc in layer_candidates:
            heatmap, debug_info = debug_gradcam_step_by_step(
                img_array, model, layer_name, pred_index
            )
            if heatmap is None:
                continue  # this layer failed; try the next candidate

            # Upsample the (usually coarse) map to the input resolution.
            if heatmap.shape[:2] != (224, 224):
                heatmap = tf.image.resize(
                    heatmap[..., np.newaxis], (224, 224)
                ).numpy()[:, :, 0]

            stats = {
                'min': float(np.min(heatmap)),
                'max': float(np.max(heatmap)),
                'mean': float(np.mean(heatmap)),
                'std': float(np.std(heatmap)),
            }
            return heatmap, f"β Grad-CAM successful using {layer_desc}", stats, debug_info

        # Every candidate failed: surface the diagnostics of the last attempt.
        return None, f"β All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info
    except Exception as e:
        return None, f"β Grad-CAM error: {str(e)}", None, {'error': str(e)}
def predict_stroke(img, model):
    """Predict stroke type from image.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1] and
    passed to ``model.predict`` as a batch of one.

    Args:
        img: PIL image in any mode (grayscale, palette, RGBA, ... are converted
            to RGB).
        model: Loaded Keras model, or ``None``.

    Returns:
        ``(probabilities, None)`` on success, where ``probabilities`` is the
        model's output vector for the image; ``(None, error_message)`` on
        failure.
    """
    if model is None:
        return None, "Model not loaded"
    try:
        # convert('RGB') first: grayscale becomes 3 identical channels (as the
        # previous np.stack did), palette ('P') images are mapped to real
        # colours instead of raw palette indices, and RGBA/LA images drop their
        # alpha channel — this file feeds the model 3 channels elsewhere, so a
        # 2- or 4-channel array would presumably be rejected.
        rgb = img.convert('RGB').resize((224, 224))
        batch = np.expand_dims(np.array(rgb, dtype=np.float32), axis=0) / 255.0
        predictions = model.predict(batch, verbose=0)
        return predictions[0], None
    except Exception as e:
        return None, f"Prediction error: {str(e)}"
def create_enhanced_simulated_heatmap(img, predictions):
    """Create a synthetic 224x224 "attention" heatmap used when Grad-CAM fails.

    The map is NOT derived from the model: it blends three Gaussian blobs whose
    position depends only on the predicted class and whose amplitude scales
    with the top confidence, plus fixed-seed noise, then normalises to [0, 1].

    Args:
        img: Input image (unused; kept for interface compatibility).
        predictions: Class-probability vector ordered like ``STROKE_LABELS``.

    Returns:
        ``(heatmap, status, stats)`` with a ``(224, 224)`` array, or
        ``(None, error_message, None)`` if generation fails.
    """
    try:
        confidence = np.max(predictions)
        predicted_class = int(np.argmax(predictions))

        # Primary blob centre (x, y) per class: 0 = hemorrhagic (centre-left),
        # 1 = ischemic (right side); anything else (no stroke) = image centre.
        class_centers = {0: (80, 112), 1: (150, 112)}
        center_x, center_y = class_centers.get(predicted_class, (112, 112))

        y, x = np.ogrid[:224, :224]

        def gaussian_blob(dx, dy, sigma):
            """Isotropic Gaussian centred at (center_x - dx, center_y - dy)."""
            return np.exp(-((x - center_x + dx) ** 2 + (y - center_y + dy) ** 2) / (2 * (sigma ** 2)))

        primary = gaussian_blob(0, 0, 40)
        secondary_a = gaussian_blob(30, 20, 25)
        secondary_b = gaussian_blob(-20, -30, 30)
        heatmap = (primary * 0.8 + secondary_a * 0.4 + secondary_b * 0.3) * confidence

        # Fixed-seed noise for realism. A private RandomState(42) yields exactly
        # the same numbers as np.random.seed(42) + np.random.normal, but does not
        # reset NumPy's global random state for the rest of the process.
        rng = np.random.RandomState(42)
        noise = rng.normal(0, 0.05, heatmap.shape)
        heatmap = np.maximum(heatmap + noise, 0)

        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        stats = {
            'min': float(np.min(heatmap)),
            'max': float(np.max(heatmap)),
            'mean': float(np.mean(heatmap)),
            'std': float(np.std(heatmap)),
        }
        return heatmap, "β οΈ Using enhanced simulated heatmap", stats
    except Exception as e:
        return None, f"β Simulated heatmap error: {str(e)}", None
def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
    """Build a 3-panel figure: original image, heatmap, and heatmap overlay.

    Grad-CAM is tried first (when ``force_gradcam`` and a model are given); if
    it yields no heatmap, the simulated heatmap is used instead.

    Args:
        img: PIL image to visualise.
        predictions: Class-probability vector for ``img``.
        model: Keras model, or ``None`` to skip Grad-CAM.
        force_gradcam: Whether to attempt Grad-CAM at all.
        colormap: Matplotlib colormap name for the heatmap panels.

    Returns:
        Always a 4-tuple ``(figure, status_message, stats, debug_info)``;
        ``figure`` is ``None`` on failure and the message explains why.
    """
    if not MPL_AVAILABLE:
        return None, "β Matplotlib not available", None, None
    fig = None
    try:
        # RGB so grayscale scans are drawn in gray: a 2-D array would otherwise
        # be rendered through matplotlib's default colormap.
        img_array = np.array(img.convert('RGB').resize((224, 224)))

        heatmap = None
        status_message = ""
        stats = None
        debug_info = None

        if force_gradcam and model is not None:
            result = create_robust_gradcam_heatmap(img, model, predictions)
            # Accept both 3-tuples (no debug info) and 4-tuples.
            if result and len(result) >= 3:
                heatmap, status_message, stats = result[0], result[1], result[2]
                if len(result) > 3:
                    debug_info = result[3]

        # Fallback to the simulated heatmap if Grad-CAM was skipped or failed.
        if heatmap is None:
            result = create_enhanced_simulated_heatmap(img, predictions)
            if result and len(result) == 3:
                heatmap, sim_status, stats = result
                status_message = f"{status_message} | {sim_status}" if status_message else sim_status

        if heatmap is None:
            return None, "β Could not generate any heatmap", None, None

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # 1. Original image
        axes[0].imshow(img_array)
        axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
        axes[0].axis('off')

        # 2. Heatmap only
        im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
        axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
        axes[1].axis('off')
        fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

        # 3. Overlay; the title says whether the map came from real Grad-CAM.
        axes[2].imshow(img_array)
        im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
        if "β Grad-CAM successful" in status_message:
            title, title_color = "π― Real AI Attention Overlay", 'green'
        else:
            title, title_color = "π¨ Simulated Attention Overlay", 'orange'
        axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
        axes[2].axis('off')
        fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

        fig.tight_layout()
        return fig, status_message, stats, debug_info
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # don't leave a half-built figure in pyplot's registry
        return None, f"β Visualization error: {str(e)}", None, None
# Main App
_WELCOME_MARKDOWN = """
## π Welcome to the Comprehensive Stroke Classification System

This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.

### π§ New Debugging Features:

- **Step-by-step Grad-CAM debugging** - See exactly where it fails
- **Multiple layer attempts** - Tries different layers automatically
- **Enhanced error messages** - Clear explanations of what went wrong
- **NaN detection** - Identifies computation errors

### π― What to Look For:

- **Green success messages** - Grad-CAM is working
- **Orange warnings** - Using fallback methods
- **Red errors** - Something is broken
- **NaN statistics** - Computation failure

**Upload an image to see detailed debugging! π**
"""


def _render_system_status():
    """Show dependency/model status, the architecture summary and a reload button."""
    import html  # only needed to escape status text embedded in raw HTML

    st.markdown("### π§ System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if TF_AVAILABLE:
            st.markdown('<div class="status-box success">β TensorFlow Ready</div>', unsafe_allow_html=True)
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.markdown('<div class="status-box error">β TensorFlow Error</div>', unsafe_allow_html=True)
    with col2:
        if MPL_AVAILABLE:
            st.markdown('<div class="status-box success">β Matplotlib Ready</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">β Matplotlib Error</div>', unsafe_allow_html=True)
    with col3:
        # Decide from the loaded model itself rather than by substring-matching
        # the human-readable status: as written in this file, the success and
        # failure messages begin with the same marker, so "β " matches both.
        if st.session_state.model is not None:
            st.markdown('<div class="status-box success">β Model Loaded</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">β Model Error</div>', unsafe_allow_html=True)

    # The status may contain exception text or file names; escape it so "<...>"
    # is displayed rather than interpreted as HTML.
    status_text = html.escape(str(st.session_state.model_status), quote=False)
    st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {status_text}</div>', unsafe_allow_html=True)

    if st.session_state.model is not None:
        with st.expander("π Detailed Model Architecture Analysis"):
            analysis = analyze_model_architecture(st.session_state.model)
            st.write("**π Model Summary:**")
            st.write(f"- **Model Type:** {analysis['model_type']}")
            st.write(f"- **Total Layers:** {analysis['total_layers']}")
            st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
            st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
            st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
            st.write("**π All Layers (Detailed):**")
            for layer in analysis['all_layers_detailed']:
                activation_info = f" | Activation: {layer['activation']}" if layer['activation'] else ""
                st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | Shape: {layer['output_shape']}{activation_info}")

    if st.button("π Reload Model", help="Try to reload the model"):
        st.session_state.model_loaded = False
        st.rerun()


def _render_sidebar():
    """Render the upload widget and display options; return ``(uploaded_file, options)``."""
    with st.sidebar:
        st.header("π€ Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification"
        )
        st.markdown("---")
        st.header("π¨ Visualization Options")
        # Dict literals evaluate left to right, so widgets keep their order.
        options = {
            'force_gradcam': st.checkbox(
                "Attempt Grad-CAM",
                value=True,
                help="Try Grad-CAM with comprehensive debugging"
            ),
            'colormap': st.selectbox(
                "Color Scheme",
                ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
                index=0,
                help="Choose color scheme for heatmap visualization"
            ),
            'show_probabilities': st.checkbox("Show All Probabilities", value=True),
            'show_debug': st.checkbox("Show Debug Info", value=True),
            'show_stats': st.checkbox("Show Heatmap Statistics", value=True),
            'show_detailed_debug': st.checkbox("Show Detailed Debug Info", value=False),
        }
    return uploaded_file, options


def _render_classification(image, show_probabilities):
    """Classify ``image`` and render the result; return the probabilities or ``None``."""
    st.subheader("π Classification Results")
    if st.session_state.model is None:
        st.error("β Model not loaded. Check the debug information above to see available files.")
        return None

    with st.spinner("π Analyzing brain scan..."):
        predictions, error = predict_stroke(image, st.session_state.model)
    if error:
        st.error(error)
        return None

    class_idx = np.argmax(predictions)
    confidence = predictions[class_idx] * 100
    predicted_class = STROKE_LABELS[class_idx]
    st.markdown(f"""
<div class="prediction-box">
<h2>{predicted_class}</h2>
<h3>Confidence: {confidence:.1f}%</h3>
</div>
""", unsafe_allow_html=True)

    if show_probabilities:
        st.write("**π All Probabilities:**")
        for label, prob in zip(STROKE_LABELS, predictions):
            st.write(f"β’ {label}: {prob*100:.1f}%")
    return predictions


def _render_attention(image, predictions, options):
    """Render the attention figure plus its status, statistics and debug details."""
    st.subheader("π― Comprehensive AI Attention Visualization")
    if st.session_state.model is None or predictions is None:
        st.info("Upload an image and run classification to see AI attention visualization")
        return

    with st.spinner("π¨ Generating comprehensive attention visualization..."):
        result = create_comprehensive_visualization(
            image,
            predictions,
            st.session_state.model,
            options['force_gradcam'],
            options['colormap']
        )
    if not result or len(result) < 2:
        st.error("Could not generate attention visualization")
        return

    overlay_fig, status_message = result[0], result[1]
    stats = result[2] if len(result) > 2 else None
    debug_info = result[3] if len(result) > 3 else None

    if overlay_fig is None:
        st.error(f"Could not generate visualization: {status_message}")
        if debug_info:
            st.error(f"Debug info: {debug_info.get('error', 'No additional info')}")
        return

    st.pyplot(overlay_fig)
    plt.close(overlay_fig)  # release this figure explicitly, not "the current" one

    if options['show_debug']:
        if "β Grad-CAM successful" in status_message:
            st.success(f"β {status_message}")
        elif "β οΈ" in status_message:
            st.warning(f"β οΈ {status_message}")
        else:
            st.error(f"β {status_message}")

    if options['show_stats'] and stats:
        st.write("**π Heatmap Statistics:**")
        if any(np.isnan([stats['min'], stats['max'], stats['mean'], stats['std']])):
            st.error("β οΈ NaN values detected in heatmap - this indicates a computation error")
        else:
            col_stats1, col_stats2 = st.columns(2)
            with col_stats1:
                st.write(f"β’ Min: {stats['min']:.3f}")
                st.write(f"β’ Max: {stats['max']:.3f}")
            with col_stats2:
                st.write(f"β’ Mean: {stats['mean']:.3f}")
                st.write(f"β’ Std: {stats['std']:.3f}")

    if options['show_detailed_debug'] and debug_info:
        with st.expander("π§ Detailed Debug Information"):
            st.json(debug_info)


def main():
    """Render the whole page: header, status panel, sidebar, results and disclaimer."""
    st.markdown('<h1 class="main-header">π§ AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)

    # Load the model once per session (the reload button resets the flag).
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    _render_system_status()
    uploaded_file, options = _render_sidebar()

    if uploaded_file is None:
        st.markdown(_WELCOME_MARKDOWN)
    else:
        image = None
        try:
            image = Image.open(uploaded_file)
            image.load()  # decode now so corrupt/truncated files fail here
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Could not read the uploaded image: {e}")
        if image is not None:
            col1, col2 = st.columns([1, 2])
            with col1:
                predictions = _render_classification(image, options['show_probabilities'])
            with col2:
                _render_attention(image, predictions, options)

    st.markdown("---")
    st.warning("β οΈ **Medical Disclaimer:** This AI system is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical decisions.")
if __name__ == "__main__":
    # Render the app when this file is executed as the main script.
    main()