# DeepStroke / app.py
# huseyincavus's picture
# Update app.py
# 3803121 verified
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torchvision.models.resnet import Bottleneck
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import io
import base64
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
# Set device
# Chosen once at import time: the CUDA GPU when PyTorch reports one is
# available, otherwise the CPU. Both the model and every input tensor are
# moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Model Architecture
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate that rescales each channel of a feature map.

    The input is average-pooled to one value per channel, passed through a
    bias-free two-layer bottleneck MLP (``channel -> channel // reduction ->
    channel``) ending in a sigmoid, and the resulting per-channel weights are
    multiplied back onto the input.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channel`` for the hidden width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # Attribute names are part of the checkpoint's state-dict keys.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate.

        ``x`` must be 4-dimensional; the first two axes are used as
        (batch, channel).
        """
        batch, channels, _rows, _cols = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
class SEBottleneck(Bottleneck):
    """torchvision ``Bottleneck`` with an SE gate on the residual branch.

    The constructor forwards every standard argument to ``Bottleneck`` and
    adds one ``SELayer`` sized to the block's output channels
    (``planes * expansion``).

    Args:
        inplanes, planes, stride, downsample, groups, base_width, dilation,
        norm_layer: Passed through to ``Bottleneck.__init__`` positionally.
        se_reduction: Reduction ratio handed to ``SELayer``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, se_reduction=16):
        super().__init__(inplanes, planes, stride, downsample,
                         groups, base_width, dilation, norm_layer)
        # Attribute name `se` is part of the checkpoint's state-dict keys.
        self.se = SELayer(planes * self.expansion, reduction=se_reduction)

    def forward(self, x):
        """Run conv1-bn1-relu, conv2-bn2-relu, conv3-bn3, SE, then add the shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # The SE gate is applied to the bn3 output, before the residual add.
        out = self.se(self.bn3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
def get_seresnext50(num_classes=1, se_reduction=16):
    """Create an SE-ResNeXt50 model.

    Starts from torchvision's ImageNet-pretrained ``resnext50_32x4d``, swaps
    every ``Bottleneck`` for an ``SEBottleneck`` that inherits the original
    block's weights, and replaces the final fully connected layer.

    Args:
        num_classes: Output size of the new ``fc`` layer (1 = single logit).
        se_reduction: Reduction ratio used by every inserted SE layer.

    Returns:
        The modified torchvision ResNet instance.
    """
    # `pretrained=True` is deprecated since torchvision 0.13 in favour of the
    # `weights=` enum. Use the enum when this torchvision provides it and fall
    # back to the legacy flag otherwise; both select IMAGENET1K_V1.
    weights_enum = getattr(models, "ResNeXt50_32X4D_Weights", None)
    if weights_enum is not None:
        model = models.resnext50_32x4d(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnext50_32x4d(pretrained=True)
    base_width = model.base_width

    def replace_bottlenecks(module, se_reduction_ratio, base_width):
        """Recursively replace each Bottleneck child of ``module`` in place."""
        # Iterate over a snapshot: the loop body rebinds entries of the very
        # child registry that named_children() walks.
        for name, child_module in list(module.named_children()):
            if not isinstance(child_module, Bottleneck):
                replace_bottlenecks(child_module, se_reduction_ratio, base_width)
                continue

            # Recover the constructor arguments from the existing block.
            new_bottleneck = SEBottleneck(
                inplanes=child_module.conv1.in_channels,
                planes=child_module.conv3.out_channels // child_module.expansion,
                stride=child_module.stride,
                downsample=child_module.downsample,
                groups=child_module.conv2.groups,
                base_width=base_width,
                dilation=child_module.conv2.dilation[0],
                se_reduction=se_reduction_ratio,
            )
            # strict=False because the source block has no `se.*` parameters;
            # those keep their fresh initialisation.
            new_bottleneck.load_state_dict(child_module.state_dict(), strict=False)
            setattr(module, name, new_bottleneck)

    replace_bottlenecks(model, se_reduction, base_width)

    # Replace final layer for binary classification (single output)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model
# Load the trained model
@torch.no_grad()
def load_model():
    """Load the trained SE-ResNeXt50 model and return it in eval mode.

    The checkpoint may be either a dict holding the weights under
    ``'model_state_dict'`` or a bare state dict. A leading ``'module.'`` on
    parameter names (typically left by saving a model wrapped in
    ``nn.DataParallel``) is stripped in both cases.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.

    Raises:
        Whatever ``torch.load`` raises when the checkpoint file is missing or
        unreadable, and ``RuntimeError`` from the strict load of a bare state
        dict that does not match the model.
    """
    model_path = 'best_seresnext50_model.pth'  # Relative path

    # Create model
    model = get_seresnext50(num_classes=1, se_reduction=16)

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider `weights_only=True` where the installed
    # PyTorch supports it and the checkpoint contains only tensors.
    checkpoint = torch.load(model_path, map_location=device)

    is_wrapped = 'model_state_dict' in checkpoint
    state_dict = checkpoint['model_state_dict'] if is_wrapped else checkpoint

    # Remove only a *leading* 'module.'; str.replace would also rewrite the
    # substring wherever else it occurs inside a parameter name.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    if is_wrapped:
        # Kept non-strict as before, but no longer silent: a partially loaded
        # network would otherwise serve predictions without any indication.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(
                "WARNING: checkpoint did not match the model exactly - "
                f"missing keys: {list(result.missing_keys)}, "
                f"unexpected keys: {list(result.unexpected_keys)}"
            )
    else:
        model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
# Initialize model
# Runs at import time: the checkpoint is loaded once and the resulting
# module-level `model` is the one predict_stroke calls for every image.
print("Loading DeepStroke Model...")
model = load_model()
print(f"Model loaded successfully on {device}")
# Image preprocessing
# Applied to every PIL image before inference: resize to 224x224, convert to a
# tensor, then normalize each of the three channels with the mean/std below
# (the values commonly used with ImageNet-pretrained torchvision models).
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Optimal threshold from evaluation (Youden's Index)
# predict_stroke reports a stroke when the sigmoid probability is >= this
# value; the gauge in create_prediction_visualization marks it as well.
OPTIMAL_THRESHOLD = 0.4902
def _assess_risk(probability, prediction):
    """Map a probability and its thresholded decision to the report wording.

    Returns:
        tuple: (prediction_text, risk_level, recommendation)
    """
    if prediction:
        if probability > 0.8:
            risk_level = "HIGH RISK"
        elif probability > 0.6:
            risk_level = "MODERATE RISK"
        else:
            risk_level = "ELEVATED RISK"
        prediction_text = f"🚨 **STROKE DETECTED** - {risk_level}"
        recommendation = "⚠️ **URGENT**: Immediate medical attention required. Contact emergency services."
        return prediction_text, risk_level, recommendation

    if probability < 0.2:
        risk_level = "LOW RISK"
        recommendation = "✅ **Low stroke probability detected**. Continue routine medical care."
    elif probability < 0.4:
        risk_level = "MILD RISK"
        recommendation = "⚠️ **Mild concern**. Consider consulting with a neurologist."
    else:
        # Below the decision threshold but not clearly negative.
        risk_level = "UNCERTAIN"
        recommendation = "⚠️ **Borderline case**. Additional imaging or clinical assessment recommended."
    prediction_text = f"✅ **NO STROKE DETECTED** - {risk_level}"
    return prediction_text, risk_level, recommendation


def predict_stroke(image):
    """
    Predict stroke probability from brain CT image

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and scored by the module-level ``model``; the sigmoid of the
    single logit is compared against ``OPTIMAL_THRESHOLD``.

    Args:
        image: PIL Image or numpy array, or None when nothing was uploaded

    Returns:
        tuple: (prediction_text, probability, confidence, detailed_analysis, visualization)
        ``probability`` and ``confidence`` are percentage strings. When
        ``image`` is None, or when any step raises, the tuple carries
        placeholder/error text instead and ``visualization`` is None; no
        exception propagates to the caller.
    """
    try:
        # Check if image is None
        if image is None:
            return "No image provided", "N/A", "N/A", "Please upload an image to analyze.", None

        print(f"DEBUG: Processing image type: {type(image)}")  # Debug print

        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        print(f"DEBUG: Image size: {image.size}, mode: {image.mode}")  # Debug print

        # Preprocess image (adds the batch dimension the model expects)
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Get prediction
        with torch.no_grad():
            output = model(image_tensor)
            probability = torch.sigmoid(output).item()

        # Use optimal threshold for prediction
        prediction = probability >= OPTIMAL_THRESHOLD
        # Distance from 50%, not from the decision threshold.
        confidence = max(probability, 1 - probability)

        # Create prediction text with clinical interpretation
        prediction_text, risk_level, recommendation = _assess_risk(probability, prediction)

        # Detailed analysis
        detailed_analysis = f"""
**🔬 DETAILED ANALYSIS**
**Prediction:** {prediction_text}
**Stroke Probability:** {probability:.1%}
**Model Confidence:** {confidence:.1%}
**Risk Assessment:** {risk_level}
**Clinical Threshold:** {OPTIMAL_THRESHOLD:.1%} (Optimized using Youden's Index)
**Model Performance:** ROC-AUC 0.98+ on validation data
**⚕️ Clinical Recommendation:**
{recommendation}
**📋 Important Notes:**
- This AI model is for assistance only and should not replace professional medical diagnosis
- Always consult with qualified medical professionals for definitive diagnosis
- Consider patient clinical history and symptoms alongside AI predictions
"""

        # Create visualization
        visualization = create_prediction_visualization(probability, prediction, confidence)

        return prediction_text, f"{probability:.1%}", f"{confidence:.1%}", detailed_analysis, visualization
    except Exception as e:
        # UI boundary: report the failure in the outputs instead of raising.
        import traceback
        error_msg = f"Error during prediction: {str(e)}\n\nTraceback: {traceback.format_exc()}"
        print(f"DEBUG: Prediction error - {error_msg}")  # Debug print
        return f"⚠️ **PREDICTION ERROR**\n\n{str(e)}", "Error", "N/A", error_msg, None
def create_prediction_visualization(probability, prediction, confidence):
    """Build a single-gauge Plotly figure summarising one prediction.

    Args:
        probability: Stroke probability as a fraction; shown as a percentage.
        prediction: Truthy when a stroke was detected; selects red vs green.
        confidence: Fraction printed in the status annotation.

    Returns:
        A ``go.Figure`` with the gauge and a status annotation. If building it
        raises, the error is printed and a figure that only shows the error
        message is returned instead.
    """
    try:
        positive_color = '#dc3545'  # Red
        negative_color = '#28a745'  # Green
        verdict_color = positive_color if prediction else negative_color
        threshold_pct = OPTIMAL_THRESHOLD * 100

        # Gauge body: tinted bands on either side of the decision threshold
        # and a black marker line on the threshold itself.
        gauge_spec = {
            'axis': {'range': [0, 100], 'tickfont': {'size': 14}},
            'bar': {'color': verdict_color, 'thickness': 0.8},
            'steps': [
                {'range': [0, threshold_pct], 'color': "rgba(40, 167, 69, 0.2)"},
                {'range': [threshold_pct, 100], 'color': "rgba(220, 53, 69, 0.2)"},
            ],
            'threshold': {
                'line': {'color': "#000", 'width': 3},
                'thickness': 0.8,
                'value': threshold_pct,
            },
        }

        fig = go.Figure()
        fig.add_trace(
            go.Indicator(
                mode="gauge+number+delta",
                value=probability * 100,
                title={
                    'text': "Stroke Probability",
                    'font': {'size': 18, 'family': 'Arial'},
                },
                number={
                    'font': {'size': 32, 'color': verdict_color, 'family': 'Arial Black'},
                    'suffix': '%',
                },
                # Delta is measured against the decision threshold.
                delta={
                    'reference': threshold_pct,
                    'increasing': {'color': positive_color},
                    'decreasing': {'color': negative_color},
                },
                gauge=gauge_spec,
            )
        )

        # Status box near the bottom of the figure.
        headline = "⚠️ STROKE DETECTED" if prediction else "✅ NO STROKE DETECTED"
        caption = "<br>".join([
            f"<b style='color:{verdict_color};font-size:16px'>{headline}</b>",
            f"<span style='font-size:12px'>Confidence: {confidence:.0%}</span>",
            f"<span style='font-size:11px'>Threshold: {OPTIMAL_THRESHOLD:.0%}</span>",
        ])
        fig.add_annotation(
            x=0.5, y=0.1,
            xref="paper", yref="paper",
            text=caption,
            showarrow=False,
            font={'size': 14, 'color': verdict_color},
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor=verdict_color,
            borderwidth=2,
            borderpad=10,
            xanchor="center",
        )

        fig.update_layout(
            height=400,
            showlegend=False,
            title={
                'text': "🧠 Stroke Detection Result",
                'x': 0.5,
                'y': 0.95,
                'font': {'size': 24, 'family': 'Arial Black', 'color': '#2C3E50'},
            },
            font={'size': 12, 'family': 'Arial'},
            margin=dict(t=60, b=40, l=40, r=40),
            plot_bgcolor='white',
            paper_bgcolor='white',
        )
        return fig
    except Exception as e:
        import traceback
        print(f"DEBUG: Visualization error - {str(e)}\n{traceback.format_exc()}")
        # Fall back to a bare figure that just displays the error text.
        fallback = go.Figure()
        fallback.add_annotation(
            x=0.5, y=0.5,
            text=f"Visualization Error: {str(e)}",
            showarrow=False,
            font={'size': 16, 'color': 'red'},
        )
        return fallback
def create_model_info():
    """Return the static HTML for the model information card.

    The markup is styled by the ``.model-info`` and ``.model-specs-grid``
    rules in the application's custom CSS. It takes no input and contains no
    dynamic values; the threshold shown is hard-coded text.

    Returns:
        str: An HTML fragment.
    """
    # The "<" in "<1s" is written as &lt; — a raw "<" followed by a digit is
    # an HTML parse error even though browsers recover from it.
    info_html = """
<div class="model-info">
<h2>🧠 DeepStroke AI - Advanced Brain CT Stroke Detection</h2>
<div class="model-specs-grid">
<div>
<h3>🏗️ Model Architecture</h3>
<ul>
<li><strong>🔬 Network:</strong> SE-ResNeXt50</li>
<li><strong>📐 Input Size:</strong> 224×224 RGB</li>
<li><strong>⚙️ Parameters:</strong> ~25M trained</li>
<li><strong>🎯 Threshold:</strong> 49.02% (Optimized)</li>
<li><strong>🧮 SE Ratio:</strong> 16 (Attention)</li>
</ul>
</div>
<div>
<h3>📊 Performance Metrics</h3>
<ul>
<li><strong>🎯 ROC-AUC:</strong> 0.98+ (Excellent)</li>
<li><strong>🔍 Sensitivity:</strong> High stroke detection</li>
<li><strong>✅ Specificity:</strong> Low false alarms</li>
<li><strong>📋 Validation:</strong> External datasets</li>
<li><strong>⚡ Speed:</strong> &lt;1s inference</li>
</ul>
</div>
</div>
<div style="margin-top: 20px; padding-top: 15px; border-top: 2px solid rgba(255,255,255,0.3);">
<p style="text-align: center; margin: 0; font-size: 1.1em;">
<strong>🚀 Latest Model Version 1.0</strong> |
<em>Trained on 10,000+ validated NON-CONTRAST brain CT scans</em>
</p>
<p style="text-align: center; margin: 5px 0 0 0; font-size: 0.95em; color: rgba(255,255,255,0.9);">
⚠️ <strong>ONLY for non-contrast brain CT imaging</strong>
</p>
</div>
</div>
"""
    return info_html
def create_clinical_guidelines():
    """Return the static HTML for the clinical guidelines panel.

    The fragment has three sections — critical safety reminders, best
    practices and imaging limitations — and relies on the
    ``.clinical-guidelines`` family of rules in the application's custom CSS.

    Returns:
        str: An HTML fragment with no dynamic values.
    """
    return """
<div class="clinical-guidelines">
<h3 class="guidelines-title">⚕️ Clinical Usage Guidelines & Safety Protocols</h3>
<div style="margin: 20px 0;">
<h4 class="critical-section">🚨 CRITICAL SAFETY REMINDERS</h4>
<ul class="critical-list">
<li><strong>⚠️ AI ASSISTANCE ONLY</strong> - This tool provides diagnostic support but cannot replace professional medical judgment</li>
<li><strong>👨‍⚕️ ALWAYS CONSULT PHYSICIANS</strong> - Qualified medical professionals must make final diagnostic and treatment decisions</li>
<li><strong>⏰ TIME-CRITICAL CASES</strong> - In suspected acute stroke, follow standard emergency protocols regardless of AI output</li>
<li><strong>🧠 NON-CONTRAST BRAIN CT ONLY</strong> - This model was trained exclusively on non-contrast brain CT scans and will fail on other imaging types</li>
</ul>
</div>
<div style="margin: 20px 0;">
<h4 class="best-practices-section">✅ BEST PRACTICES</h4>
<ul class="guidelines-list">
<li><strong>🔍 Image Quality:</strong> Ensure CT scans have adequate contrast and clear anatomical landmarks</li>
<li><strong>🎯 Threshold:</strong> 49% threshold optimized for balanced sensitivity/specificity</li>
<li><strong>🔄 Cross-validation:</strong> Compare AI findings with clinical assessments</li>
<li><strong>👥 Team Approach:</strong> Involve radiologists and neurologists in complex cases</li>
</ul>
</div>
<div style="margin: 20px 0; padding: 15px; background: rgba(220,53,69,0.1); border-radius: 10px; border-left: 4px solid #dc3545;">
<h4 style="color: #dc3545; margin-top: 0;">⚠️ IMAGING LIMITATIONS</h4>
<p style="margin-bottom: 8px; font-weight: 600; color: #dc3545;">
🚨 <strong>This model will FAIL on:</strong>
</p>
<ul style="margin: 0; color: #495057;">
<li><strong>Contrast-enhanced CT scans, MRI images, X-rays</strong></li>
<li><strong>Other organ imaging</strong> (chest, abdomen, spine, etc.)</li>
<li><strong>Pediatric scans or post-surgical images</strong></li>
</ul>
<p style="margin-top: 10px; font-weight: 600; color: #dc3545;">
⚡ <strong>Using inappropriate image types will produce unreliable results!</strong>
</p>
</div>
</div>
"""
# Create Gradio Interface
def create_gradio_app():
"""Create the main Gradio application with improved styling

Builds a ``gr.Blocks`` app that contains, in creation order: a header
banner, the model-info and clinical-guidelines HTML panels, an image input
with an analyze button, the result widgets (prediction markdown, probability
and confidence text boxes, detailed-analysis markdown, plot), a set of
clickable example images and a footer.

``predict_stroke`` is wired to both the button click and the image
component's change event, writing to the same five outputs each time.

Returns:
    The ``gr.Blocks`` instance; the caller is responsible for launching it.
"""
# NOTE(review): the row/column nesting of the layout contexts below is not
# recoverable from indentation in this copy of the file — confirm against the
# deployed layout before restructuring.
# Custom CSS for medical-grade styling with dark mode support and animations
# Passed verbatim to gr.Blocks(css=...). The class names match those used in
# the HTML fragments from create_model_info / create_clinical_guidelines and
# in the header/footer markup below.
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
max-width: 1400px !important;
margin: auto;
animation: fadeIn 0.5s ease-in;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.02); }
100% { transform: scale(1); }
}
@keyframes shimmer {
0% { background-position: -200px 0; }
100% { background-position: 200px 0; }
}
/* Clinical Guidelines Styling - Dark Mode Compatible */
.clinical-guidelines {
background: var(--background-fill-primary, #ffffff);
border: 2px solid var(--border-color-primary, #e0e0e0);
border-left: 5px solid #007bff;
padding: 25px;
border-radius: 15px;
margin: 15px 0;
box-shadow: 0 4px 15px rgba(0,0,0,0.1);
color: var(--body-text-color, #000000);
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.clinical-guidelines:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0,0,0,0.15);
}
.clinical-guidelines::before {
content: '';
position: absolute;
top: -2px;
left: -2px;
right: -2px;
bottom: -2px;
background: linear-gradient(45deg, #007bff, #28a745, #ffc107, #dc3545);
border-radius: 15px;
z-index: -1;
opacity: 0;
transition: opacity 0.3s ease;
}
.clinical-guidelines:hover::before {
opacity: 0.1;
}
.guidelines-title {
color: #007bff !important;
margin-top: 0 !important;
font-weight: 700;
font-size: 1.3em;
text-shadow: 0 1px 2px rgba(0,0,0,0.1);
}
.critical-section {
color: #dc3545 !important;
font-weight: 700;
margin-bottom: 10px;
font-size: 1.1em;
display: flex;
align-items: center;
gap: 8px;
}
.best-practices-section {
color: #28a745 !important;
font-weight: 700;
margin-bottom: 10px;
font-size: 1.1em;
display: flex;
align-items: center;
gap: 8px;
}
.requirements-section {
color: #fd7e14 !important;
font-weight: 700;
margin-bottom: 10px;
font-size: 1.1em;
display: flex;
align-items: center;
gap: 8px;
}
.critical-list {
color: #dc3545 !important;
font-weight: 600;
line-height: 1.6;
}
.critical-list li {
margin-bottom: 8px;
padding-left: 5px;
border-left: 3px solid #dc3545;
padding-left: 10px;
margin-left: 5px;
}
.guidelines-list {
color: var(--body-text-color, #333333) !important;
opacity: 0.9;
line-height: 1.6;
}
.guidelines-list li {
margin-bottom: 6px;
padding-left: 5px;
transition: all 0.2s ease;
}
.guidelines-list li:hover {
transform: translateX(5px);
color: #007bff !important;
}
/* Model Info Box - Enhanced with gradients and animations */
.model-info {
background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #667eea 100%);
background-size: 200% 200%;
animation: gradientShift 6s ease infinite;
padding: 25px;
border-radius: 15px;
color: white;
margin: 15px 0;
box-shadow: 0 6px 20px rgba(0,0,0,0.2);
position: relative;
overflow: hidden;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.model-info::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.2), transparent);
animation: shimmer 2s infinite;
}
.model-info:hover {
transform: scale(1.02);
transition: transform 0.3s ease;
}
.model-info h2 {
margin-top: 0;
text-align: center;
text-shadow: 0 2px 4px rgba(0,0,0,0.3);
font-size: 1.8em;
margin-bottom: 20px;
}
.model-specs-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 25px;
margin-top: 20px;
}
.model-specs-grid h3 {
margin-bottom: 15px;
font-size: 1.2em;
border-bottom: 2px solid rgba(255,255,255,0.3);
padding-bottom: 8px;
}
.model-specs-grid ul {
list-style: none;
padding: 0;
}
.model-specs-grid li {
margin-bottom: 8px;
padding: 8px;
background: rgba(255,255,255,0.1);
border-radius: 8px;
transition: all 0.3s ease;
}
.model-specs-grid li:hover {
background: rgba(255,255,255,0.2);
transform: translateX(5px);
}
/* Header gradient with enhanced effects */
.header-gradient {
background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #667eea 100%);
background-size: 200% 200%;
animation: gradientShift 8s ease infinite;
color: white;
padding: 30px;
border-radius: 15px;
text-align: center;
margin-bottom: 25px;
box-shadow: 0 6px 25px rgba(0,0,0,0.2);
position: relative;
overflow: hidden;
}
.header-gradient::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255,255,255,0.1), transparent);
animation: shimmer 3s infinite;
}
.header-gradient h1 {
margin: 0;
font-size: 2.8em;
text-shadow: 0 3px 6px rgba(0,0,0,0.3);
animation: pulse 2s infinite;
}
/* Enhanced button styling */
.gradio-button {
background: linear-gradient(135deg, #007bff, #0056b3) !important;
border: none !important;
border-radius: 10px !important;
padding: 12px 24px !important;
font-weight: 600 !important;
font-size: 1.1em !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 15px rgba(0,123,255,0.3) !important;
}
.gradio-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(0,123,255,0.4) !important;
background: linear-gradient(135deg, #0056b3, #004085) !important;
}
/* Enhanced input styling */
.gradio-textbox, .gradio-file {
border-radius: 10px !important;
border: 2px solid #e9ecef !important;
transition: all 0.3s ease !important;
}
.gradio-textbox:focus, .gradio-file:focus {
border-color: #007bff !important;
box-shadow: 0 0 0 3px rgba(0,123,255,0.1) !important;
transform: scale(1.01) !important;
}
/* Enhanced plot container */
.plot-container {
border-radius: 15px !important;
overflow: hidden !important;
box-shadow: 0 4px 15px rgba(0,0,0,0.1) !important;
transition: all 0.3s ease !important;
}
.plot-container:hover {
transform: translateY(-2px) !important;
box-shadow: 0 8px 25px rgba(0,0,0,0.15) !important;
}
/* Footer styling with enhanced effects */
.app-footer {
text-align: center;
margin-top: 40px;
padding: 25px;
background: var(--background-fill-primary, #ffffff);
border-radius: 15px;
border: 2px solid var(--border-color-primary, #e0e0e0);
color: var(--body-text-color, #333333);
box-shadow: 0 4px 15px rgba(0,0,0,0.1);
position: relative;
overflow: hidden;
}
.app-footer::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 2px;
background: linear-gradient(90deg, #007bff, #28a745, #ffc107, #dc3545);
animation: shimmer 3s infinite;
}
.disclaimer {
color: #dc3545 !important;
font-weight: 700;
text-shadow: 0 1px 2px rgba(0,0,0,0.1);
animation: pulse 3s infinite;
}
/* Loading animation */
.loading {
position: relative;
color: transparent;
}
.loading::after {
content: 'Processing...';
position: absolute;
top: 0;
left: 0;
color: #007bff;
animation: pulse 1.5s infinite;
}
/* Responsive design with enhanced breakpoints */
@media (max-width: 1200px) {
.gradio-container {
max-width: 95% !important;
}
}
@media (max-width: 768px) {
.model-specs-grid {
grid-template-columns: 1fr;
gap: 15px;
}
.header-gradient h1 {
font-size: 2.2em;
}
.clinical-guidelines {
padding: 20px;
}
.model-info {
padding: 20px;
}
.gradio-container {
max-width: 100% !important;
margin: 10px;
}
}
@media (max-width: 480px) {
.header-gradient h1 {
font-size: 1.8em;
}
.clinical-guidelines {
padding: 15px;
}
.model-info {
padding: 15px;
}
}
/* Dark mode enhancements */
@media (prefers-color-scheme: dark) {
.plot-container {
background: #1e1e1e !important;
}
.clinical-guidelines {
box-shadow: 0 4px 15px rgba(255,255,255,0.1);
}
.model-info {
box-shadow: 0 6px 20px rgba(255,255,255,0.1);
}
.app-footer {
box-shadow: 0 4px 15px rgba(255,255,255,0.1);
}
}
"""
# Everything created inside this context is registered on `app`.
with gr.Blocks(css=custom_css, title="DeepStroke AI - Brain CT Analysis") as app:
# Header
# NOTE(review): the inline font-size on the <h1> below takes precedence over
# the `.header-gradient h1` sizes in custom_css, including the media-query
# overrides — confirm that is intended.
gr.HTML("""
<div class="header-gradient">
<h1 style="margin: 0; font-size: 2.5em;">🧠 DeepStroke AI</h1>
<p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Brain CT Stroke Detection System</p>
<p style="margin: 5px 0 0 0; opacity: 0.9;">Powered by SE-ResNeXt50 Deep Learning Architecture</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
# Model Information
gr.HTML(create_model_info())
# Clinical Guidelines
gr.HTML(create_clinical_guidelines())
with gr.Column(scale=2):
# Main Interface
gr.Markdown("## 📤 Upload Brain CT Image")
# Example images section
# NOTE(review): this text says the examples are "below", while the
# gr.Examples component itself is created much later, after the plot —
# confirm the rendered order matches the wording.
gr.Markdown("### 🖼️ Try Example Images")
gr.Markdown("Click on any example below to test the stroke detection system:")
with gr.Row():
with gr.Column():
# Image input
# type="pil" means predict_stroke receives a PIL image (or None when the
# component is empty).
image_input = gr.Image(
label="Brain CT Scan",
type="pil",
sources=["upload", "clipboard"],
height=300
)
# Analyze button
analyze_btn = gr.Button(
"🔍 Analyze CT Scan",
variant="primary",
size="lg"
)
with gr.Column():
# Results section
gr.Markdown("## 📊 Analysis Results")
# Main prediction
prediction_output = gr.Markdown(
label="Prediction",
value="Upload an image to see results..."
)
# Metrics
# Read-only text boxes; predict_stroke supplies preformatted percentage
# strings for both.
with gr.Row():
probability_output = gr.Textbox(
label="🎯 Stroke Probability",
interactive=False,
container=True
)
confidence_output = gr.Textbox(
label="📈 Model Confidence",
interactive=False,
container=True
)
# Detailed Analysis
gr.Markdown("## 📋 Detailed Clinical Analysis")
detailed_analysis_output = gr.Markdown(
value="Detailed analysis will appear here after image upload..."
)
# Visualization
gr.Markdown("## 📊 Analysis Result")
visualization_output = gr.Plot(
label="Stroke Detection"
)
# Create examples component for easy clicking
# Each example is a one-element row holding a relative image path that is
# loaded into image_input. The paths are resolved at runtime, so the
# ExampleIMG directory must ship alongside the app.
examples = gr.Examples(
examples=[
["ExampleIMG/10189.png"],
["ExampleIMG/10300.png"],
["ExampleIMG/13447.png"],
["ExampleIMG/14343.png"],
["ExampleIMG/15614.png"],
["ExampleIMG/16760.png"],
["ExampleIMG/17023.png"]
],
inputs=[image_input],
outputs=[
prediction_output,
probability_output,
confidence_output,
detailed_analysis_output,
visualization_output
],
fn=predict_stroke,
cache_examples=False, # Disable caching to avoid index errors
examples_per_page=7
)
# Footer
gr.HTML("""
<div class="app-footer">
<p>
<strong>DeepStroke AI v1.0</strong> |
Developed for Brain CT Stroke Detection |
<span class="disclaimer">For Research Only</span>
</p>
<p style="margin: 5px 0 0 0; font-size: 0.9em;">
Always consult with qualified medical professionals for definitive diagnosis and treatment decisions.
</p>
</div>
""")
# Event handlers
# The output order must match the 5-tuple returned by predict_stroke:
# (prediction_text, probability, confidence, detailed_analysis, visualization).
analyze_btn.click(
fn=predict_stroke,
inputs=[image_input],
outputs=[
prediction_output,
probability_output,
confidence_output,
detailed_analysis_output,
visualization_output
]
)
# Auto-analyze on image upload
# Same handler and outputs as the button, so a new image is analyzed without
# clicking; presumably this also fires when the image is cleared, in which
# case predict_stroke returns its "No image provided" placeholders.
image_input.change(
fn=predict_stroke,
inputs=[image_input],
outputs=[
prediction_output,
probability_output,
confidence_output,
detailed_analysis_output,
visualization_output
]
)
return app
if __name__ == "__main__":
    # Startup banner: report the device and decision threshold in use.
    for banner_line in (
        "🚀 Starting DeepStroke AI Application...",
        f"📱 Model loaded on device: {device}",
        f"🎯 Using optimal threshold: {OPTIMAL_THRESHOLD}",
    ):
        print(banner_line)

    # Launch configuration for medical applications
    launch_options = {
        "server_name": "0.0.0.0",  # Listen on all interfaces (external access)
        "server_port": 7860,
        "share": False,  # True would create a public link (not recommended for medical apps)
        "debug": False,
        "auth": None,  # Add authentication for production use
        "ssl_verify": True,
        "favicon_path": None,
        "inbrowser": True,
        "show_error": True,
    }

    # Create and launch the app
    app = create_gradio_app()
    app.launch(**launch_options)