# brain-tumor / app4.py
# (by ishans24; renamed from app.py, commit e21d7ad verified)
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Define class names
# predictions[i] is treated as the confidence for CLASS_NAMES[i] throughout
# the app — this order must match the models' output layer (TODO: confirm
# against the training label encoding).
CLASS_NAMES = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
# Model configurations
# For each selectable model:
#   path          - saved .keras model file loaded by load_model()
#   preprocessing - the tf.keras.applications preprocess_input matching the backbone
#   img_size      - (width, height) the image is resized to before inference
#   description   - short blurb shown in the UI dropdown info panel
MODELS = {
    "MobileNetV2": {
        "path": "best_mobilenetv2.keras",
        "preprocessing": tf.keras.applications.mobilenet_v2.preprocess_input,
        "img_size": (224, 224),
        "description": "MobileNetV2 - Lightweight and efficient"
    },
    "DenseNet121": {
        "path": "best_densenet121.keras",
        "preprocessing": tf.keras.applications.densenet.preprocess_input,
        "img_size": (224, 224),
        "description": "DenseNet121 - Dense connections for better gradient flow"
    },
    "EfficientNetV2S": {
        "path": "best_efficientnetv2s.keras",
        "preprocessing": tf.keras.applications.efficientnet_v2.preprocess_input,
        "img_size": (224, 224),
        "description": "EfficientNetV2S - State-of-the-art efficiency"
    }
}
# Load all models at startup
# Cache of instantiated Keras models, keyed by the MODELS display name.
loaded_models = {}

def load_model(model_name):
    """Return the cached Keras model for *model_name*, loading it on first use.

    On a load failure the error is printed and None is returned; the failure
    is not cached, so a later call will retry the load.
    """
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached
    try:
        loaded_models[model_name] = tf.keras.models.load_model(
            MODELS[model_name]["path"]
        )
        print(f"βœ“ Loaded {model_name}")
    except Exception as e:
        print(f"βœ— Failed to load {model_name}: {str(e)}")
        return None
    return loaded_models[model_name]

# Warm the cache now so the first user request is not delayed by model loading.
for model_name in MODELS.keys():
    load_model(model_name)
def preprocess_image(image, model_name):
    """Convert a PIL image into a model-ready batch.

    Args:
        image: PIL.Image in any mode (L, LA, P, RGB, RGBA, CMYK, ...).
        model_name: key into MODELS selecting input size and preprocessing.

    Returns:
        numpy array of shape (1, height, width, 3) after the model-specific
        preprocess_input has been applied.
    """
    img_size = MODELS[model_name]["img_size"]
    preprocessing_fn = MODELS[model_name]["preprocessing"]
    # Normalize every PIL mode to 3-channel RGB up front. The previous
    # manual branches only covered pure grayscale (H, W) and RGBA arrays,
    # and mishandled palette ('P'), gray+alpha ('LA'), and 'CMYK' images.
    img = image.convert("RGB").resize(img_size)
    img_array = np.asarray(img, dtype=np.float32)
    # Add batch dimension expected by model.predict
    img_array = np.expand_dims(img_array, axis=0)
    # Apply model-specific preprocessing (scaling/normalization for the backbone)
    return preprocessing_fn(img_array)
def create_prediction_plot(predictions, class_names):
    """Build a horizontal bar chart of per-class confidence scores.

    Bars are colored by confidence on the RdYlGn scale and labeled with the
    percentage value; the x-axis is fixed to [0, 1].
    """
    confidence_bars = go.Bar(
        x=predictions,
        y=class_names,
        orientation='h',
        marker=dict(
            color=predictions,
            colorscale='RdYlGn',
            showscale=True,
            colorbar=dict(title="Confidence")
        ),
        text=[f'{p:.2%}' for p in predictions],
        textposition='auto',
    )
    fig = go.Figure(data=[confidence_bars])
    fig.update_layout(
        title="Prediction Confidence Distribution",
        xaxis_title="Confidence Score",
        yaxis_title="Tumor Type",
        height=400,
        xaxis=dict(range=[0, 1]),
        template="plotly_white"
    )
    return fig
def create_comparison_plot(all_predictions, model_names):
    """Build a grouped bar chart comparing class confidences across models.

    One trace per tumor class; the x-axis groups bars by model so per-class
    confidences can be compared side by side.
    """
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']
    fig = go.Figure()
    for class_idx, class_label in enumerate(CLASS_NAMES):
        per_model_scores = [pred[class_idx] for pred in all_predictions]
        fig.add_trace(go.Bar(
            name=class_label,
            x=model_names,
            y=per_model_scores,
            marker_color=palette[class_idx]
        ))
    fig.update_layout(
        title="Model Comparison - Prediction Confidence",
        xaxis_title="Model",
        yaxis_title="Confidence Score",
        barmode='group',
        height=450,
        template="plotly_white",
        legend=dict(
            title="Tumor Type",
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig
def predict(image, model_name):
    """Classify *image* with the selected model.

    Returns a (markdown_text, plotly_figure, raw_prediction_vector) triple;
    on a missing image or any failure the figure and vector are None and the
    text carries the error message.
    """
    if image is None:
        return "Please upload an image", None, None
    try:
        model = load_model(model_name)
        if model is None:
            return f"Error: Could not load {model_name}", None, None
        batch = preprocess_image(image, model_name)
        predictions = model.predict(batch, verbose=0)[0]
        predicted_idx = np.argmax(predictions)
        predicted_class = CLASS_NAMES[predicted_idx]
        confidence = predictions[predicted_idx]
        # Markdown summary: headline result followed by one line per class.
        header = f"""
### πŸ”¬ Diagnosis Result ({model_name})
**Predicted Class:** {predicted_class}
**Confidence:** {confidence:.2%}
#### All Class Probabilities:
"""
        rows = []
        for i, (class_name, prob) in enumerate(zip(CLASS_NAMES, predictions)):
            emoji = "🎯" if i == predicted_idx else "πŸ“Š"
            rows.append(f"\n{emoji} **{class_name}:** {prob:.2%}")
        result_text = header + "".join(rows)
        plot = create_prediction_plot(predictions, CLASS_NAMES)
        return result_text, plot, predictions
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
def compare_models(image):
    """Run every loadable model on *image* and summarize all predictions.

    Returns a (markdown_summary, grouped_bar_figure) pair; the figure is
    None when no model could be loaded or an error occurred.
    """
    if image is None:
        return "Please upload an image", None
    try:
        all_predictions = []
        model_names = []
        result_text = "### πŸ“Š Model Comparison Results\n\n"
        for model_name in MODELS.keys():
            model = load_model(model_name)
            if model is None:
                # Skip models whose weights failed to load at startup.
                continue
            batch = preprocess_image(image, model_name)
            predictions = model.predict(batch, verbose=0)[0]
            all_predictions.append(predictions)
            model_names.append(model_name)
            top = np.argmax(predictions)
            predicted_class = CLASS_NAMES[top]
            confidence = predictions[top]
            result_text += f"**{model_name}:** {predicted_class} ({confidence:.2%})\n\n"
        if not all_predictions:
            return "Error: No models could be loaded", None
        plot = create_comparison_plot(all_predictions, model_names)
        return result_text, plot
    except Exception as e:
        return f"Error during comparison: {str(e)}", None
# Custom CSS for better styling
# Injected into gr.Blocks(css=...): sets the app font, gives primary and
# secondary buttons gradient backgrounds, and styles the markdown output
# panel as a card. Class names target Gradio's generated DOM — presumably
# tied to the installed Gradio version; verify after upgrades.
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.gr-button-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
}
.gr-button-secondary {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;
border: none !important;
}
.output-markdown {
background-color: #f8f9fa;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
"""
# Create Gradio interface
# Layout: one Blocks app with three tabs — single-model prediction,
# all-model comparison, and a static "About" page. Event wiring happens
# inline inside the layout contexts.
with gr.Blocks(css=custom_css, title="Brain Tumor Classification") as app:
    gr.Markdown(
        """
# 🧠 Brain Tumor MRI Classification System
Upload an MRI scan to classify brain tumors using state-of-the-art deep learning models.
**Tumor Types:** Glioma, Meningioma, No Tumor, Pituitary
"""
    )
    with gr.Tabs():
        # Single Model Prediction Tab
        with gr.TabItem("πŸ” Single Model Prediction"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Inputs: the MRI image and the model selector.
                    input_image = gr.Image(type="pil", label="Upload MRI Scan")
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="MobileNetV2",
                        label="Select Model",
                        info="Choose which model to use for prediction"
                    )
                    # Model description
                    model_info = gr.Markdown(MODELS["MobileNetV2"]["description"])

                    def update_model_info(model_name):
                        """Return the description text for the selected model."""
                        return MODELS[model_name]["description"]

                    # Keep the description panel in sync with the dropdown.
                    model_dropdown.change(
                        fn=update_model_info,
                        inputs=model_dropdown,
                        outputs=model_info
                    )
                    predict_btn = gr.Button("πŸ”¬ Analyze", variant="primary")
                with gr.Column(scale=2):
                    output_text = gr.Markdown(label="Prediction Results")
                    output_plot = gr.Plot(label="Confidence Distribution")
            # predict() returns (text, plot, raw prediction vector); the raw
            # vector is routed into a throwaway gr.State because the UI does
            # not display it.
            predict_btn.click(
                fn=predict,
                inputs=[input_image, model_dropdown],
                outputs=[output_text, output_plot, gr.State()]
            )
        # Model Comparison Tab
        with gr.TabItem("πŸ“Š Compare All Models"):
            with gr.Row():
                with gr.Column(scale=1):
                    compare_image = gr.Image(type="pil", label="Upload MRI Scan")
                    compare_btn = gr.Button("βš–οΈ Compare Models", variant="secondary")
                    gr.Markdown(
                        """
### Available Models:
- **MobileNetV2**: Fast and efficient
- **DenseNet121**: Deep dense connections
- **EfficientNetV2S**: Latest efficiency improvements
"""
                    )
                with gr.Column(scale=2):
                    compare_text = gr.Markdown(label="Comparison Results")
                    compare_plot = gr.Plot(label="Model Comparison Visualization")
            compare_btn.click(
                fn=compare_models,
                inputs=compare_image,
                outputs=[compare_text, compare_plot]
            )
        # Information Tab
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
## About This Application
This application uses deep learning models trained on brain MRI scans to classify different types of brain tumors.
### Tumor Types:
1. **Glioma**: A tumor that occurs in the brain and spinal cord
2. **Meningioma**: A tumor that forms on membranes covering the brain and spinal cord
3. **Pituitary**: A tumor in the pituitary gland
4. **No Tumor**: Healthy brain tissue
### Models:
- **MobileNetV2**: Lightweight architecture ideal for mobile deployment
- **DenseNet121**: Dense connections improve feature propagation
- **EfficientNetV2S**: Optimized for both accuracy and efficiency
### Image Requirements:
- Format: PNG, JPG, JPEG
- The models automatically resize images to 224x224 pixels
- Grayscale images are automatically converted to RGB
### Performance:
All models achieve >99% test accuracy on the brain tumor dataset.
---
**Note**: This is a demonstration system and should not be used for actual medical diagnosis.
Always consult with qualified healthcare professionals for medical advice.
"""
            )
# Launch the app
if __name__ == "__main__":
    # share=True additionally requests a temporary public Gradio URL
    # alongside the local server.
    app.launch(share=True)