import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torchvision.models.resnet import Bottleneck
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import io
import base64
from datetime import datetime
import warnings
# Silence every Python warning for the whole process (this also hides any
# deprecation/security warnings raised by the libraries used below).
warnings.simplefilter('ignore')
# Choose the compute device once at import time: CUDA if available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Model Architecture
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools each channel of a ``(B, C, H, W)`` input, feeds the
    pooled vector through a bias-free MLP ``C -> C // reduction -> C`` with a
    ReLU in between and a sigmoid at the end, and multiplies the input by the
    resulting per-channel gates.

    Args:
        channel: Number of input (and output) channels ``C``.
        reduction: Factor by which the hidden layer shrinks ``C``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The attribute names and Sequential indices (fc.0 / fc.2) form the
        # state_dict keys the trained checkpoint is loaded against.
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Four-way unpack keeps the original's hard requirement of a 4-D input.
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)
class SEBottleneck(Bottleneck):
    """torchvision ``Bottleneck`` with a Squeeze-and-Excitation gate.

    All convolution/normalisation layers are created by ``Bottleneck``; an
    ``SELayer`` sized to the block's output width (``planes * expansion``) is
    attached as ``self.se``. It is applied after ``bn3`` and before the
    shortcut is added.

    Args:
        inplanes .. norm_layer: Forwarded positionally to ``Bottleneck``.
        se_reduction: Reduction ratio of the attached ``SELayer``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, se_reduction=16):
        super().__init__(inplanes, planes, stride, downsample,
                         groups, base_width, dilation, norm_layer)
        self.se = SELayer(planes * self.expansion, reduction=se_reduction)

    def forward(self, x):
        branch = x
        # conv -> bn -> relu for the first two stages of the residual branch
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            branch = self.relu(norm(conv(branch)))
        # third stage has no activation before the SE gate
        branch = self.se(self.bn3(self.conv3(branch)))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
def get_seresnext50(num_classes=1, se_reduction=16):
    """Create an SE-ResNeXt50 model from torchvision's ResNeXt50-32x4d.

    Every torchvision ``Bottleneck`` in the pretrained backbone is swapped for
    an ``SEBottleneck`` with the same geometry. The original block's weights
    are copied in (``strict=False`` because the SE parameters are new), and
    its ``downsample`` module is reused as-is. The classification head ``fc``
    is then replaced by a freshly initialised ``nn.Linear`` with
    ``num_classes`` outputs.

    Args:
        num_classes: Output units of the new head (1 = a single binary logit).
        se_reduction: Reduction ratio passed to every ``SELayer``.

    Returns:
        The modified model. This function does not move it to a device or
        switch it to eval mode.
    """
    backbone = models.resnext50_32x4d(pretrained=True)
    width_per_group = backbone.base_width

    # Collect every (parent, attribute, block) triple before mutating the
    # tree. Pre-order over modules() visits the blocks in the same order as a
    # recursive descent, so the SE layers are created in the same sequence.
    targets = [
        (parent, attr, block)
        for parent in backbone.modules()
        for attr, block in parent.named_children()
        if isinstance(block, Bottleneck)
    ]
    for parent, attr, block in targets:
        se_block = SEBottleneck(
            inplanes=block.conv1.in_channels,
            planes=block.conv3.out_channels // block.expansion,
            stride=block.stride,
            downsample=block.downsample,
            groups=block.conv2.groups,
            base_width=width_per_group,
            dilation=block.conv2.dilation[0],
            se_reduction=se_reduction,
        )
        se_block.load_state_dict(block.state_dict(), strict=False)
        setattr(parent, attr, se_block)

    # Single-output head for binary classification
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone
# Load the trained model
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return ``state_dict`` with a leading ``prefix`` removed from its keys.

    Used for checkpoints whose keys carry a wrapper prefix (such as the
    ``'module.'`` added by ``nn.DataParallel``-wrapped models).

    If no key starts with ``prefix``, the mapping is returned unchanged (same
    object). Otherwise a new dict is built in which only a *leading*
    occurrence is cut. ``str.replace`` would also rewrite keys that merely
    contain the substring somewhere else.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


@torch.no_grad()
def load_model(model_path='best_seresnext50_model.pth'):
    """Build the SE-ResNeXt50 model and load the trained checkpoint.

    Args:
        model_path: Checkpoint path (by default relative to the working
            directory). The file is either a dict holding the weights under
            ``'model_state_dict'`` or a bare state dict.

    Returns:
        The model moved to ``device`` and switched to eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. when the
        file is missing or a bare state dict does not match the model.
    """
    model = get_seresnext50(num_classes=1, se_reduction=16)
    # NOTE(review): torch.load unpickles the file; only point this at
    # trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = _strip_module_prefix(checkpoint['model_state_dict'])
        # strict=False tolerates architecture drift. Report it, though:
        # otherwise mismatched layers silently keep whatever
        # get_seresnext50 initialised them to.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                "WARNING: checkpoint/model mismatch - "
                f"{len(incompatible.missing_keys)} missing key(s), "
                f"{len(incompatible.unexpected_keys)} unexpected key(s)"
            )
    else:
        # Bare state dicts may also come from a wrapped model.
        model.load_state_dict(_strip_module_prefix(checkpoint))
    model.to(device)
    model.eval()
    return model
# Decision threshold chosen during evaluation via Youden's Index
OPTIMAL_THRESHOLD = 0.4902

# Preprocessing: resize to 224x224, convert to a float tensor, then normalise
# per channel with the standard ImageNet mean/std (presumably matching the
# training-time pipeline; confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build and load the model once at import time; predictions reuse this
# module-level instance.
print("Loading DeepStroke Model...")
model = load_model()
print(f"Model loaded successfully on {device}")
def predict_stroke(image):
    """
    Predict stroke probability from a brain CT image.

    Args:
        image: PIL Image, numpy array (converted with ``Image.fromarray``),
            or None when nothing was uploaded.

    Returns:
        tuple: (prediction_text, probability, confidence, detailed_analysis,
        visualization). probability and confidence are preformatted
        percentage strings. visualization is a Plotly figure, or None when no
        image was given or prediction failed.

    Nothing is raised to the caller: any exception is caught, printed, and
    reported in the returned tuple (second element ``"Error"``).
    """
    try:
        if image is None:
            return "No image provided", "N/A", "N/A", "Please upload an image to analyze.", None

        # Normalise the input to an RGB PIL image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        image_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(image_tensor)
            probability = torch.sigmoid(output).item()

        # Decision uses the evaluation-derived threshold, not 0.5
        prediction = probability >= OPTIMAL_THRESHOLD
        confidence = max(probability, 1 - probability)

        if prediction:
            if probability > 0.8:
                risk_level = "HIGH RISK"
            elif probability > 0.6:
                risk_level = "MODERATE RISK"
            else:
                risk_level = "ELEVATED RISK"
            prediction_text = f"🚨 **STROKE DETECTED** - {risk_level}"
            recommendation = "⚠️ **URGENT**: Immediate medical attention required. Contact emergency services."
        else:
            if probability < 0.2:
                risk_level = "LOW RISK"
                recommendation = "✅ **Low stroke probability detected**. Continue routine medical care."
            elif probability < 0.4:
                risk_level = "MILD RISK"
                recommendation = "⚠️ **Mild concern**. Consider consulting with a neurologist."
            else:
                risk_level = "UNCERTAIN"
                recommendation = "⚠️ **Borderline case**. Additional imaging or clinical assessment recommended."
            prediction_text = f"✅ **NO STROKE DETECTED** - {risk_level}"

        # Markdown report; content lines stay at column 0 so the rendered
        # text is unchanged.
        detailed_analysis = f"""
**🔬 DETAILED ANALYSIS**
**Prediction:** {prediction_text}
**Stroke Probability:** {probability:.1%}
**Model Confidence:** {confidence:.1%}
**Risk Assessment:** {risk_level}
**Clinical Threshold:** {OPTIMAL_THRESHOLD:.1%} (Optimized using Youden's Index)
**Model Performance:** ROC-AUC 0.98+ on validation data
**⚕️ Clinical Recommendation:**
{recommendation}
**📋 Important Notes:**
- This AI model is for assistance only and should not replace professional medical diagnosis
- Always consult with qualified medical professionals for definitive diagnosis
- Consider patient clinical history and symptoms alongside AI predictions
"""

        visualization = create_prediction_visualization(probability, prediction, confidence)
        return prediction_text, f"{probability:.1%}", f"{confidence:.1%}", detailed_analysis, visualization
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        import traceback
        error_msg = f"Error during prediction: {str(e)}\n\nTraceback: {traceback.format_exc()}"
        print(f"DEBUG: Prediction error - {error_msg}")
        return f"⚠️ **PREDICTION ERROR**\n\n{str(e)}", "Error", "N/A", error_msg, None
def create_prediction_visualization(probability, prediction, confidence):
    """Create a simple single-gauge Plotly figure for one prediction.

    Args:
        probability: Stroke probability in [0, 1]; shown on a 0-100 gauge.
        prediction: Truthy when the case was classified as stroke. It selects
            the red/green colour scheme and the status text.
        confidence: Confidence value in [0, 1], printed as a percentage.

    Returns:
        A ``plotly.graph_objects.Figure``. If building it fails, a blank
        figure carrying the error message as an annotation is returned
        instead, so the caller always gets a figure.
    """
    try:
        color_safe = '#28a745'  # Green
        color_danger = '#dc3545'  # Red
        gauge_color = color_danger if prediction else color_safe
        threshold_pct = OPTIMAL_THRESHOLD * 100

        fig = go.Figure()
        # Gauge shows the probability in percent. The delta is measured
        # against the decision threshold, and the threshold line marks it.
        fig.add_trace(
            go.Indicator(
                mode="gauge+number+delta",
                value=probability * 100,
                title={
                    'text': "Stroke Probability",
                    'font': {'size': 18, 'family': 'Arial'}
                },
                number={
                    'font': {'size': 32, 'color': gauge_color, 'family': 'Arial Black'},
                    'suffix': '%'
                },
                delta={
                    'reference': threshold_pct,
                    'increasing': {'color': color_danger},
                    'decreasing': {'color': color_safe}
                },
                gauge={
                    'axis': {
                        'range': [0, 100],
                        'tickfont': {'size': 14}
                    },
                    'bar': {
                        'color': gauge_color,
                        'thickness': 0.8
                    },
                    'steps': [
                        {'range': [0, threshold_pct], 'color': "rgba(40, 167, 69, 0.2)"},
                        {'range': [threshold_pct, 100], 'color': "rgba(220, 53, 69, 0.2)"}
                    ],
                    'threshold': {
                        'line': {'color': "#000", 'width': 3},
                        'thickness': 0.8,
                        'value': threshold_pct
                    }
                }
            )
        )

        status_text = "⚠️ STROKE DETECTED" if prediction else "✅ NO STROKE DETECTED"
        status_color = color_danger if prediction else color_safe
        fig.add_annotation(
            x=0.5, y=0.1,
            xref="paper", yref="paper",
            # Plotly annotation text breaks lines with HTML <br>. A raw
            # newline inside these literals was a SyntaxError.
            text=(
                f"{status_text}<br>"
                f"Confidence: {confidence:.0%}<br>"
                f"Threshold: {OPTIMAL_THRESHOLD:.0%}"
            ),
            showarrow=False,
            font={'size': 14, 'color': status_color},
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor=status_color,
            borderwidth=2,
            borderpad=10,
            xanchor="center"
        )

        fig.update_layout(
            height=400,
            showlegend=False,
            title={
                'text': "🧠 Stroke Detection Result",
                'x': 0.5,
                'y': 0.95,
                'font': {'size': 24, 'family': 'Arial Black', 'color': '#2C3E50'}
            },
            font={'size': 12, 'family': 'Arial'},
            margin=dict(t=60, b=40, l=40, r=40),
            plot_bgcolor='white',
            paper_bgcolor='white'
        )
        return fig
    except Exception as e:
        import traceback
        print(f"DEBUG: Visualization error - {str(e)}\n{traceback.format_exc()}")
        # Fall back to a figure that just shows the error
        fig = go.Figure()
        fig.add_annotation(
            x=0.5, y=0.5,
            text=f"Visualization Error: {str(e)}",
            showarrow=False,
            font={'size': 16, 'color': 'red'}
        )
        return fig
def create_model_info():
"""Create enhanced model information display with interactive elements"""
info_html = """
🚀 Latest Model Version 1.0 | Trained on 10,000+ validated NON-CONTRAST brain CT scans
⚠️ ONLY for non-contrast brain CT imaging
🚨 This model will FAIL on:
⚡ Using inappropriate image types will produce unreliable results!
Advanced Brain CT Stroke Detection System
Powered by SE-ResNeXt50 Deep Learning Architecture