import gradio as gr import torch import torch.nn as nn import torchvision.transforms as transforms from torchvision import models from torchvision.models.resnet import Bottleneck from PIL import Image import numpy as np import matplotlib.pyplot as plt import seaborn as sns import plotly.graph_objects as go import plotly.express as px from plotly.subplots import make_subplots import io import base64 from datetime import datetime import warnings warnings.filterwarnings('ignore') # Set device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Model Architecture class SELayer(nn.Module): def __init__(self, channel, reduction=16): super(SELayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) class SEBottleneck(Bottleneck): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, se_reduction=16): super(SEBottleneck, self).__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer) self.se = SELayer(planes * self.expansion, reduction=se_reduction) def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) out = self.se(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out def get_seresnext50(num_classes=1, se_reduction=16): """Create SE-ResNeXt50 model""" model = models.resnext50_32x4d(pretrained=True) base_width = model.base_width def replace_bottlenecks(module, se_reduction_ratio, base_width): for name, child_module in module.named_children(): if isinstance(child_module, Bottleneck): inplanes = child_module.conv1.in_channels planes = child_module.conv3.out_channels // child_module.expansion stride = child_module.stride downsample = child_module.downsample groups = child_module.conv2.groups dilation = child_module.conv2.dilation[0] new_bottleneck = SEBottleneck( inplanes=inplanes, planes=planes, stride=stride, downsample=downsample, groups=groups, base_width=base_width, dilation=dilation, se_reduction=se_reduction_ratio ) new_bottleneck.load_state_dict(child_module.state_dict(), strict=False) setattr(module, name, new_bottleneck) else: replace_bottlenecks(child_module, se_reduction_ratio, base_width) replace_bottlenecks(model, se_reduction, base_width) # Replace final layer for binary classification (single output) in_features = model.fc.in_features model.fc = nn.Linear(in_features, num_classes) return model # Load the trained model @torch.no_grad() def load_model(): """Load the trained SE-ResNeXt50 model""" model_path = 'best_seresnext50_model.pth' # Relative path # Create model model = get_seresnext50(num_classes=1, se_reduction=16) # Load checkpoint checkpoint = torch.load(model_path, map_location=device) if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] # Remove 'module.' prefix if present if any(key.startswith('module.') for key in state_dict.keys()): state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()} model.load_state_dict(state_dict, strict=False) else: model.load_state_dict(checkpoint) model.to(device) model.eval() return model # Initialize model print("Loading DeepStroke Model...") model = load_model() print(f"Model loaded successfully on {device}") # Image preprocessing transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Optimal threshold from evaluation (Youden's Index) OPTIMAL_THRESHOLD = 0.4902 def predict_stroke(image): """ Predict stroke probability from brain CT image Args: image: PIL Image or numpy array Returns: tuple: (prediction_text, probability, confidence, detailed_analysis, visualization) """ try: # Check if image is None if image is None: return "No image provided", "N/A", "N/A", "Please upload an image to analyze.", None print(f"DEBUG: Processing image type: {type(image)}") # Debug print # Convert to PIL Image if needed if isinstance(image, np.ndarray): image = Image.fromarray(image) # Convert to RGB if needed if image.mode != 'RGB': image = image.convert('RGB') print(f"DEBUG: Image size: {image.size}, mode: {image.mode}") # Debug print # Preprocess image image_tensor = transform(image).unsqueeze(0).to(device) # Get prediction with torch.no_grad(): output = model(image_tensor) probability = torch.sigmoid(output).item() # Use optimal threshold for prediction prediction = probability >= OPTIMAL_THRESHOLD confidence = max(probability, 1 - probability) # Create prediction text with clinical interpretation if prediction: risk_level = "HIGH RISK" if probability > 0.8 else "MODERATE RISK" if probability > 0.6 else "ELEVATED RISK" prediction_text = f"🚨 **STROKE DETECTED** - {risk_level}" color = "#ff4444" recommendation = "⚠️ **URGENT**: Immediate medical attention required. Contact emergency services." else: if probability < 0.2: risk_level = "LOW RISK" recommendation = "✅ **Low stroke probability detected**. Continue routine medical care." elif probability < 0.4: risk_level = "MILD RISK" recommendation = "⚠️ **Mild concern**. Consider consulting with a neurologist." else: risk_level = "UNCERTAIN" recommendation = "⚠️ **Borderline case**. Additional imaging or clinical assessment recommended." prediction_text = f"✅ **NO STROKE DETECTED** - {risk_level}" color = "#44ff44" # Detailed analysis detailed_analysis = f""" **🔬 DETAILED ANALYSIS** **Prediction:** {prediction_text} **Stroke Probability:** {probability:.1%} **Model Confidence:** {confidence:.1%} **Risk Assessment:** {risk_level} **Clinical Threshold:** {OPTIMAL_THRESHOLD:.1%} (Optimized using Youden's Index) **Model Performance:** ROC-AUC 0.98+ on validation data **⚕️ Clinical Recommendation:** {recommendation} **📋 Important Notes:** - This AI model is for assistance only and should not replace professional medical diagnosis - Always consult with qualified medical professionals for definitive diagnosis - Consider patient clinical history and symptoms alongside AI predictions """ # Create visualization visualization = create_prediction_visualization(probability, prediction, confidence) return prediction_text, f"{probability:.1%}", f"{confidence:.1%}", detailed_analysis, visualization except Exception as e: import traceback error_msg = f"Error during prediction: {str(e)}\n\nTraceback: {traceback.format_exc()}" print(f"DEBUG: Prediction error - {error_msg}") # Debug print return f"⚠️ **PREDICTION ERROR**\n\n{str(e)}", "Error", "N/A", error_msg, None def create_prediction_visualization(probability, prediction, confidence): """Create a simple and clean visualization with essential information""" try: # Simple color scheme color_safe = '#28a745' # Green color_warning = '#ffc107' # Yellow color_danger = '#dc3545' # Red # Determine colors based on prediction gauge_color = color_danger if prediction else color_safe # Create a single-panel dashboard with just the probability gauge fig = go.Figure() # Main probability gauge - simplified fig.add_trace( go.Indicator( mode="gauge+number+delta", value=probability * 100, title={ 'text': "Stroke Probability", 'font': {'size': 18, 'family': 'Arial'} }, number={ 'font': {'size': 32, 'color': gauge_color, 'family': 'Arial Black'}, 'suffix': '%' }, delta={ 'reference': OPTIMAL_THRESHOLD * 100, 'increasing': {'color': color_danger}, 'decreasing': {'color': color_safe} }, gauge={ 'axis': { 'range': [0, 100], 'tickfont': {'size': 14} }, 'bar': { 'color': gauge_color, 'thickness': 0.8 }, 'steps': [ {'range': [0, OPTIMAL_THRESHOLD * 100], 'color': "rgba(40, 167, 69, 0.2)"}, {'range': [OPTIMAL_THRESHOLD * 100, 100], 'color': "rgba(220, 53, 69, 0.2)"} ], 'threshold': { 'line': {'color': "#000", 'width': 3}, 'thickness': 0.8, 'value': OPTIMAL_THRESHOLD * 100 } } ) ) # Add simple status annotation status_text = "⚠️ STROKE DETECTED" if prediction else "✅ NO STROKE DETECTED" status_color = color_danger if prediction else color_safe fig.add_annotation( x=0.5, y=0.1, xref="paper", yref="paper", text=f"{status_text}
" + f"Confidence: {confidence:.0%}
" + f"Threshold: {OPTIMAL_THRESHOLD:.0%}", showarrow=False, font={'size': 14, 'color': status_color}, bgcolor="rgba(255,255,255,0.9)", bordercolor=status_color, borderwidth=2, borderpad=10, xanchor="center" ) # Simple layout configuration fig.update_layout( height=400, # Much smaller height showlegend=False, title={ 'text': "🧠 Stroke Detection Result", 'x': 0.5, 'y': 0.95, 'font': {'size': 24, 'family': 'Arial Black', 'color': '#2C3E50'} }, font={'size': 12, 'family': 'Arial'}, margin=dict(t=60, b=40, l=40, r=40), plot_bgcolor='white', paper_bgcolor='white' ) return fig except Exception as e: import traceback print(f"DEBUG: Visualization error - {str(e)}\n{traceback.format_exc()}") # Return a simple figure on error fig = go.Figure() fig.add_annotation( x=0.5, y=0.5, text=f"Visualization Error: {str(e)}", showarrow=False, font={'size': 16, 'color': 'red'} ) return fig def create_model_info(): """Create enhanced model information display with interactive elements""" info_html = """

🧠 DeepStroke AI - Advanced Brain CT Stroke Detection

🏗️ Model Architecture

  • 🔬 Network: SE-ResNeXt50
  • 📐 Input Size: 224×224 RGB
  • ⚙️ Parameters: ~25M trained
  • 🎯 Threshold: 49.02% (Optimized)
  • 🧮 SE Ratio: 16 (Attention)

📊 Performance Metrics

  • 🎯 ROC-AUC: 0.98+ (Excellent)
  • 🔍 Sensitivity: High stroke detection
  • ✅ Specificity: Low false alarms
  • 📋 Validation: External datasets
  • ⚡ Speed: <1s inference

🚀 Latest Model Version 1.0 | Trained on 10,000+ validated NON-CONTRAST brain CT scans

⚠️ ONLY for non-contrast brain CT imaging

""" return info_html def create_clinical_guidelines(): """Create enhanced clinical guidelines display with improved interactivity""" guidelines_html = """

⚕️ Clinical Usage Guidelines & Safety Protocols

🚨 CRITICAL SAFETY REMINDERS

✅ BEST PRACTICES

⚠️ IMAGING LIMITATIONS

🚨 This model will FAIL on:

Using inappropriate image types will produce unreliable results!

""" return guidelines_html # Create Gradio Interface def create_gradio_app(): """Create the main Gradio application with improved styling""" # Custom CSS for medical-grade styling with dark mode support and animations custom_css = """ .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 1400px !important; margin: auto; animation: fadeIn 0.5s ease-in; } @keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } } @keyframes pulse { 0% { transform: scale(1); } 50% { transform: scale(1.02); } 100% { transform: scale(1); } } @keyframes shimmer { 0% { background-position: -200px 0; } 100% { background-position: 200px 0; } } /* Clinical Guidelines Styling - Dark Mode Compatible */ .clinical-guidelines { background: var(--background-fill-primary, #ffffff); border: 2px solid var(--border-color-primary, #e0e0e0); border-left: 5px solid #007bff; padding: 25px; border-radius: 15px; margin: 15px 0; box-shadow: 0 4px 15px rgba(0,0,0,0.1); color: var(--body-text-color, #000000); transition: all 0.3s ease; position: relative; overflow: hidden; } .clinical-guidelines:hover { transform: translateY(-2px); box-shadow: 0 6px 20px rgba(0,0,0,0.15); } .clinical-guidelines::before { content: ''; position: absolute; top: -2px; left: -2px; right: -2px; bottom: -2px; background: linear-gradient(45deg, #007bff, #28a745, #ffc107, #dc3545); border-radius: 15px; z-index: -1; opacity: 0; transition: opacity 0.3s ease; } .clinical-guidelines:hover::before { opacity: 0.1; } .guidelines-title { color: #007bff !important; margin-top: 0 !important; font-weight: 700; font-size: 1.3em; text-shadow: 0 1px 2px rgba(0,0,0,0.1); } .critical-section { color: #dc3545 !important; font-weight: 700; margin-bottom: 10px; font-size: 1.1em; display: flex; align-items: center; gap: 8px; } .best-practices-section { color: #28a745 !important; font-weight: 700; margin-bottom: 10px; font-size: 1.1em; display: flex; align-items: center; gap: 8px; } .requirements-section { color: #fd7e14 !important; font-weight: 700; margin-bottom: 10px; font-size: 1.1em; display: flex; align-items: center; gap: 8px; } .critical-list { color: #dc3545 !important; font-weight: 600; line-height: 1.6; } .critical-list li { margin-bottom: 8px; padding-left: 5px; border-left: 3px solid #dc3545; padding-left: 10px; margin-left: 5px; } .guidelines-list { color: var(--body-text-color, #333333) !important; opacity: 0.9; line-height: 1.6; } .guidelines-list li { margin-bottom: 6px; padding-left: 5px; transition: all 0.2s ease; } .guidelines-list li:hover { transform: translateX(5px); color: #007bff !important; } /* Model Info Box - Enhanced with gradients and animations */ .model-info { background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #667eea 100%); background-size: 200% 200%; animation: gradientShift 6s ease infinite; padding: 25px; border-radius: 15px; color: white; margin: 15px 0; box-shadow: 0 6px 20px rgba(0,0,0,0.2); position: relative; overflow: hidden; } @keyframes gradientShift { 0% { background-position: 0% 50%; } 50% { background-position: 100% 50%; } 100% { background-position: 0% 50%; } } .model-info::before { content: ''; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255,255,255,0.2), transparent); animation: shimmer 2s infinite; } .model-info:hover { transform: scale(1.02); transition: transform 0.3s ease; } .model-info h2 { margin-top: 0; text-align: center; text-shadow: 0 2px 4px rgba(0,0,0,0.3); font-size: 1.8em; margin-bottom: 20px; } .model-specs-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 25px; margin-top: 20px; } .model-specs-grid h3 { margin-bottom: 15px; font-size: 1.2em; border-bottom: 2px solid rgba(255,255,255,0.3); padding-bottom: 8px; } .model-specs-grid ul { list-style: none; padding: 0; } .model-specs-grid li { margin-bottom: 8px; padding: 8px; background: rgba(255,255,255,0.1); border-radius: 8px; transition: all 0.3s ease; } .model-specs-grid li:hover { background: rgba(255,255,255,0.2); transform: translateX(5px); } /* Header gradient with enhanced effects */ .header-gradient { background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #667eea 100%); background-size: 200% 200%; animation: gradientShift 8s ease infinite; color: white; padding: 30px; border-radius: 15px; text-align: center; margin-bottom: 25px; box-shadow: 0 6px 25px rgba(0,0,0,0.2); position: relative; overflow: hidden; } .header-gradient::before { content: ''; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255,255,255,0.1), transparent); animation: shimmer 3s infinite; } .header-gradient h1 { margin: 0; font-size: 2.8em; text-shadow: 0 3px 6px rgba(0,0,0,0.3); animation: pulse 2s infinite; } /* Enhanced button styling */ .gradio-button { background: linear-gradient(135deg, #007bff, #0056b3) !important; border: none !important; border-radius: 10px !important; padding: 12px 24px !important; font-weight: 600 !important; font-size: 1.1em !important; transition: all 0.3s ease !important; box-shadow: 0 4px 15px rgba(0,123,255,0.3) !important; } .gradio-button:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(0,123,255,0.4) !important; background: linear-gradient(135deg, #0056b3, #004085) !important; } /* Enhanced input styling */ .gradio-textbox, .gradio-file { border-radius: 10px !important; border: 2px solid #e9ecef !important; transition: all 0.3s ease !important; } .gradio-textbox:focus, .gradio-file:focus { border-color: #007bff !important; box-shadow: 0 0 0 3px rgba(0,123,255,0.1) !important; transform: scale(1.01) !important; } /* Enhanced plot container */ .plot-container { border-radius: 15px !important; overflow: hidden !important; box-shadow: 0 4px 15px rgba(0,0,0,0.1) !important; transition: all 0.3s ease !important; } .plot-container:hover { transform: translateY(-2px) !important; box-shadow: 0 8px 25px rgba(0,0,0,0.15) !important; } /* Footer styling with enhanced effects */ .app-footer { text-align: center; margin-top: 40px; padding: 25px; background: var(--background-fill-primary, #ffffff); border-radius: 15px; border: 2px solid var(--border-color-primary, #e0e0e0); color: var(--body-text-color, #333333); box-shadow: 0 4px 15px rgba(0,0,0,0.1); position: relative; overflow: hidden; } .app-footer::before { content: ''; position: absolute; top: 0; left: -100%; width: 100%; height: 2px; background: linear-gradient(90deg, #007bff, #28a745, #ffc107, #dc3545); animation: shimmer 3s infinite; } .disclaimer { color: #dc3545 !important; font-weight: 700; text-shadow: 0 1px 2px rgba(0,0,0,0.1); animation: pulse 3s infinite; } /* Loading animation */ .loading { position: relative; color: transparent; } .loading::after { content: 'Processing...'; position: absolute; top: 0; left: 0; color: #007bff; animation: pulse 1.5s infinite; } /* Responsive design with enhanced breakpoints */ @media (max-width: 1200px) { .gradio-container { max-width: 95% !important; } } @media (max-width: 768px) { .model-specs-grid { grid-template-columns: 1fr; gap: 15px; } .header-gradient h1 { font-size: 2.2em; } .clinical-guidelines { padding: 20px; } .model-info { padding: 20px; } .gradio-container { max-width: 100% !important; margin: 10px; } } @media (max-width: 480px) { .header-gradient h1 { font-size: 1.8em; } .clinical-guidelines { padding: 15px; } .model-info { padding: 15px; } } /* Dark mode enhancements */ @media (prefers-color-scheme: dark) { .plot-container { background: #1e1e1e !important; } .clinical-guidelines { box-shadow: 0 4px 15px rgba(255,255,255,0.1); } .model-info { box-shadow: 0 6px 20px rgba(255,255,255,0.1); } .app-footer { box-shadow: 0 4px 15px rgba(255,255,255,0.1); } } """ with gr.Blocks(css=custom_css, title="DeepStroke AI - Brain CT Analysis") as app: # Header gr.HTML("""

🧠 DeepStroke AI

Advanced Brain CT Stroke Detection System

Powered by SE-ResNeXt50 Deep Learning Architecture

""") with gr.Row(): with gr.Column(scale=1): # Model Information gr.HTML(create_model_info()) # Clinical Guidelines gr.HTML(create_clinical_guidelines()) with gr.Column(scale=2): # Main Interface gr.Markdown("## 📤 Upload Brain CT Image") # Example images section gr.Markdown("### 🖼️ Try Example Images") gr.Markdown("Click on any example below to test the stroke detection system:") with gr.Row(): with gr.Column(): # Image input image_input = gr.Image( label="Brain CT Scan", type="pil", sources=["upload", "clipboard"], height=300 ) # Analyze button analyze_btn = gr.Button( "🔍 Analyze CT Scan", variant="primary", size="lg" ) with gr.Column(): # Results section gr.Markdown("## 📊 Analysis Results") # Main prediction prediction_output = gr.Markdown( label="Prediction", value="Upload an image to see results..." ) # Metrics with gr.Row(): probability_output = gr.Textbox( label="🎯 Stroke Probability", interactive=False, container=True ) confidence_output = gr.Textbox( label="📈 Model Confidence", interactive=False, container=True ) # Detailed Analysis gr.Markdown("## 📋 Detailed Clinical Analysis") detailed_analysis_output = gr.Markdown( value="Detailed analysis will appear here after image upload..." ) # Visualization gr.Markdown("## 📊 Analysis Result") visualization_output = gr.Plot( label="Stroke Detection" ) # Create examples component for easy clicking examples = gr.Examples( examples=[ ["ExampleIMG/10189.png"], ["ExampleIMG/10300.png"], ["ExampleIMG/13447.png"], ["ExampleIMG/14343.png"], ["ExampleIMG/15614.png"], ["ExampleIMG/16760.png"], ["ExampleIMG/17023.png"] ], inputs=[image_input], outputs=[ prediction_output, probability_output, confidence_output, detailed_analysis_output, visualization_output ], fn=predict_stroke, cache_examples=False, # Disable caching to avoid index errors examples_per_page=7 ) # Footer gr.HTML(""" """) # Event handlers analyze_btn.click( fn=predict_stroke, inputs=[image_input], outputs=[ prediction_output, probability_output, confidence_output, detailed_analysis_output, visualization_output ] ) # Auto-analyze on image upload image_input.change( fn=predict_stroke, inputs=[image_input], outputs=[ prediction_output, probability_output, confidence_output, detailed_analysis_output, visualization_output ] ) return app if __name__ == "__main__": print("🚀 Starting DeepStroke AI Application...") print(f"📱 Model loaded on device: {device}") print(f"🎯 Using optimal threshold: {OPTIMAL_THRESHOLD}") # Create and launch the app app = create_gradio_app() # Launch with configuration for medical applications app.launch( server_name="0.0.0.0", # Allow external access server_port=7860, share=False, # Set to True for public sharing (not recommended for medical apps) debug=False, auth=None, # Add authentication for production use ssl_verify=True, favicon_path=None, inbrowser=True, show_error=True )