Kaynaaf's picture
Create app.py
ecfd876 verified
raw
history blame
9.16 kB
import gradio as gr
import matplotlib.pyplot as plt
import keras
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
import io
import glob
import os
import pickle
from tqdm import tqdm
import numpy as np
# Load the trained classifier from disk. pickle.load can execute arbitrary code,
# so this must only ever point at a trusted, local model file.
with open('model.pkl', mode='rb') as model_file:
    model = pickle.load(model_file)

# Human-readable label for each model output: index i names the i-th entry of
# the prediction vector returned by model.predict.
class_names = [
    'glioma',
    'meningioma',
    'notumor',
    'pituitary',
]
def create_custom_colormap():
    """Return a colormap that ramps linearly from transparent black to opaque red.

    Because the alpha channel rises with the value, low values are nearly
    invisible and only high values show up as solid red when the colormap is
    drawn on top of another image.
    """
    # (R, G, B, A) endpoints: fully transparent at 0.0, solid red at 1.0.
    colors = [(0, 0, 0, 0), (1, 0, 0, 1)]
    return LinearSegmentedColormap.from_list('custom', colors)
def occlusion_sensitivity(model, img_array, class_index, patch_size=25, stride=5, occlusion_value=0, progress=gr.Progress()):
    """Compute an occlusion-sensitivity map for a single image.

    A square patch filled with ``occlusion_value`` is slid over the image in
    steps of ``stride`` pixels. For every position, the drop in the model's
    score for ``class_index`` (unoccluded score minus occluded score) is
    recorded. Each pixel's value is the mean drop over all patch positions
    that covered it.

    Args:
        model: Object exposing ``predict(batch, verbose=0)`` that returns one
            row of per-class scores per image in ``batch``.
        img_array: Image array indexed as ``(height, width, channels)``.
        class_index: Index of the class score to track.
        patch_size: Side length of the square occluding patch, in pixels.
        stride: Step between consecutive patch positions, in pixels.
        occlusion_value: Value written into the occluded patch.
        progress: Gradio progress tracker; updated once per row of patches.

    Returns:
        Tuple ``(sensitivity_map, original_prob)``. ``sensitivity_map`` is a
        ``(height, width)`` float array. ``original_prob`` is the unoccluded
        score for ``class_index``. Pixels never covered by any patch (possible
        on the bottom/right edges when ``stride`` does not divide the
        remaining space evenly, or everywhere when the patch is larger than
        the image) are 0.
    """
    height, width = img_array.shape[0], img_array.shape[1]
    original_pred = model.predict(np.expand_dims(img_array, axis=0), verbose=0)[0]
    original_prob = original_pred[class_index]

    # Clamp at zero so a patch larger than the image yields no positions
    # instead of a negative step count.
    n_steps_h = max(0, (height - patch_size) // stride + 1)
    n_steps_w = max(0, (width - patch_size) // stride + 1)
    total_steps = n_steps_h * n_steps_w
    if total_steps == 0:
        return np.zeros((height, width)), original_prob

    drop_sum = np.zeros((height, width))
    coverage = np.zeros((height, width))
    w_starts = [w * stride for w in range(n_steps_w)]
    current_step = 0
    for h in range(n_steps_h):
        h_start = h * stride
        h_end = h_start + patch_size
        # Occlude every position of this row in its own copy of the image and
        # score the whole row with one predict call. This is far cheaper than
        # one call per patch.
        batch = np.repeat(np.expand_dims(img_array, axis=0), n_steps_w, axis=0)
        for i, w_start in enumerate(w_starts):
            batch[i, h_start:h_end, w_start:w_start + patch_size, :] = occlusion_value
        row_probs = model.predict(batch, verbose=0)[:, class_index]
        for i, w_start in enumerate(w_starts):
            w_end = w_start + patch_size
            drop_sum[h_start:h_end, w_start:w_end] += original_prob - row_probs[i]
            coverage[h_start:h_end, w_start:w_end] += 1
        current_step += n_steps_w
        progress(current_step / total_steps, desc=f"Analyzing sensitivity: {current_step}/{total_steps}")

    # Average instead of sum. With overlapping patches, centre pixels are
    # covered far more often than border pixels, so a raw sum would bias the
    # map toward the image centre and make its scale depend on patch_size/stride.
    sensitivity_map = np.divide(drop_sum, coverage, out=np.zeros_like(drop_sum), where=coverage > 0)
    return sensitivity_map, original_prob
def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return the PNG as a PIL image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Close this exact figure (not pyplot's "current" one) even if saving
        # fails, so figures never pile up in pyplot's registry.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def create_sensitivity_visualizations(image, sensitivity_map, predicted_class, confidence, overlay_threshold=0.15):
    """Render the three result panels: original image, sensitivity map, and overlay.

    Args:
        image: Image shown in grayscale in the first and third panels. It is
            expected to have the same height/width as ``sensitivity_map`` so
            the overlay lines up.
        sensitivity_map: 2-D array of importance scores. The default overlay
            threshold presumably assumes it is scaled to [0, 1].
        predicted_class: Index into ``class_names`` used in the first title.
        confidence: Score of ``predicted_class``, shown as a percentage.
        overlay_threshold: Scores below this value are masked out of the overlay.

    Returns:
        Tuple of PIL images ``(original_img, sensitivity_img, overlay_img)``.
    """
    # Every figure is driven through its own Figure/Axes methods rather than the
    # pyplot "current figure" globals (plt.colorbar / plt.savefig / plt.close()),
    # which are shared process-wide state.

    # Original image, titled with the prediction
    fig1, ax1 = plt.subplots(figsize=(6, 6))
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Original Image\n{class_names[predicted_class]} ({confidence:.1%})', fontsize=12, fontweight='bold')
    ax1.axis('off')
    fig1.tight_layout()
    original_img = _figure_to_image(fig1)

    # Raw sensitivity map
    fig2, ax2 = plt.subplots(figsize=(6, 6))
    im2 = ax2.imshow(sensitivity_map, cmap='jet')
    ax2.set_title('Sensitivity Map', fontsize=12, fontweight='bold')
    ax2.axis('off')
    fig2.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
    fig2.tight_layout()
    sensitivity_img = _figure_to_image(fig2)

    # Overlay: only scores at or above the threshold are drawn, in red, over the image
    fig3, ax3 = plt.subplots(figsize=(6, 6))
    ax3.imshow(image, cmap='gray')
    masked_sensitivity = np.ma.masked_where(sensitivity_map < overlay_threshold, sensitivity_map)
    im3 = ax3.imshow(masked_sensitivity, cmap=create_custom_colormap(), alpha=0.6)
    ax3.set_title('Overlay (Red = High Importance)', fontsize=12, fontweight='bold')
    ax3.axis('off')
    fig3.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
    fig3.tight_layout()
    overlay_img = _figure_to_image(fig3)

    return original_img, sensitivity_img, overlay_img
def _format_prediction_markdown(predictions, predicted_class, confidence):
    """Build the Markdown report: the headline prediction plus every class's probability.

    Classes are listed from most to least likely. Each gets a text bar of up to
    20 blocks, proportional to its probability.
    """
    parts = [
        "## 🎯 Prediction Results\n\n",
        f"### **Predicted Class:** {class_names[predicted_class].upper()}\n",
        f"### **Confidence:** {confidence:.2%}\n\n",
        "---\n\n",
        "### 📊 Class Probabilities:\n\n",
    ]
    for i in np.argsort(predictions)[::-1]:
        bar = "█" * int(predictions[i] * 20)
        parts.append(f"**{class_names[i].capitalize()}:** {predictions[i]:.2%} {bar}\n\n")
    return "".join(parts)


def predict_and_visualize(image, patch_size=25, stride=5, progress=gr.Progress()):
    """Classify an MRI image and build its occlusion-sensitivity visualisations.

    Args:
        image: PIL image or array-like image; ``None`` when nothing is selected.
        patch_size: Occluding patch side length, in pixels (coerced to an int >= 1).
        stride: Step between patch positions, in pixels (coerced to an int >= 1).
        progress: Gradio progress tracker forwarded to the occlusion analysis.

    Returns:
        Tuple ``(original_img, sensitivity_img, overlay_img, markdown_text)``.
        When ``image`` is None, the three images are None and the text is a
        warning.
    """
    if image is None:
        return None, None, None, "⚠️ Please select or upload an image"
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('L').resize((256, 256))
    # `img_to_array` was never imported (NameError on every call). Build the same
    # channels-last (256, 256, 1) float32 array with NumPy and scale pixels to [0, 1].
    img_array = np.asarray(image, dtype=np.float32)[..., np.newaxis] / 255.0

    # Slider/API values may arrive as floats. The occlusion loop needs ints for
    # range() and slicing, and a stride below 1 would divide by zero.
    patch_size = max(1, int(patch_size))
    stride = max(1, int(stride))

    # Initial prediction
    predictions = model.predict(np.expand_dims(img_array, axis=0), verbose=0)[0]
    predicted_class = int(np.argmax(predictions))
    confidence = predictions[predicted_class]
    pred_text = _format_prediction_markdown(predictions, predicted_class, confidence)

    # Sensitivity map for the predicted class, with progress tracking
    sensitivity_map, _ = occlusion_sensitivity(
        model, img_array, predicted_class,
        patch_size=patch_size, stride=stride,
        progress=progress,
    )
    # Min-max scale to [0, 1]; eps keeps a constant map from dividing by zero.
    eps = 1e-9
    lo, hi = np.min(sensitivity_map), np.max(sensitivity_map)
    sensitivity_map = (sensitivity_map - lo) / (hi - lo + eps)

    original_img, sensitivity_img, overlay_img = create_sensitivity_visualizations(
        image, sensitivity_map, predicted_class, confidence
    )
    progress(1.0, desc="Complete!")
    return original_img, sensitivity_img, overlay_img, pred_text
def get_example_images(examples_folder='examples/'):
    """Return the sorted paths of image files directly inside ``examples_folder``.

    Extensions ``.jpg``, ``.jpeg`` and ``.png`` are matched case-insensitively,
    so each file is listed exactly once. Separate ``*.jpg`` / ``*.JPG`` glob
    patterns would list it twice on case-insensitive filesystems. Hidden files
    (leading ``.``) and sub-directories are skipped. Returns an empty list when
    ``examples_folder`` is not an existing directory.
    """
    if not os.path.isdir(examples_folder):
        return []
    image_exts = ('.jpg', '.jpeg', '.png')
    image_paths = [
        os.path.join(examples_folder, name)
        for name in os.listdir(examples_folder)
        if not name.startswith('.') and name.lower().endswith(image_exts)
    ]
    return sorted(path for path in image_paths if os.path.isfile(path))
def load_example_image(image_path):
    """Open an example image file and return it as a PIL image.

    Returns None when ``image_path`` is empty or None. The pixel data is read
    eagerly and the file is closed immediately. A bare ``Image.open`` is lazy
    and would keep the file handle open for as long as the image object lives.
    """
    if not image_path:
        return None
    with Image.open(image_path) as img:
        img.load()  # decode now; the image stays usable after the file is closed
    return img
# Folder scanned for bundled example images (relative to the working directory).
EXAMPLES_DIR = 'examples/'

with gr.Blocks(title="Brain MRI Tumor Classifier", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
# 🧠 Brain MRI Tumor Classifier with Occlusion Sensitivity

This application classifies brain MRI images into four categories and visualizes which regions
are most important for the classification decision.

**Classes:** Glioma • Meningioma • No Tumor • Pituitary
""")
    with gr.Row():
        # Left column - input controls
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Select or Upload Image")
            # Example image selector
            example_images = get_example_images(EXAMPLES_DIR)
            if example_images:
                example_dropdown = gr.Dropdown(
                    choices=example_images,
                    label="Select from the examples folder",
                    value=None,
                    interactive=True,
                )
            else:
                # Keep a disabled dropdown so the layout is the same with or without examples.
                example_dropdown = gr.Dropdown(
                    choices=[],
                    label=f"No examples found in '{EXAMPLES_DIR}' folder",
                    interactive=False,
                )
            # OR upload a custom image
            gr.Markdown("**OR**")
            input_image = gr.Image(type="pil", label="Upload Your Own MRI Image")
            gr.Markdown("### ⚙️ Sensitivity Analysis Settings")
            patch_size_slider = gr.Slider(
                minimum=10, maximum=50, value=25, step=5,
                label="Patch Size",
                info="Larger = faster but less detailed",
            )
            stride_slider = gr.Slider(
                minimum=5, maximum=20, value=5, step=5,
                label="Stride",
                info="Larger = faster but less detailed",
            )
            predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
            gr.Markdown("""
---

**ℹ️ Tip:** Select an image from the dropdown or upload your own,
then click 'Analyze Image' to see predictions and sensitivity maps.
""")
        # Right column - results
        with gr.Column(scale=2):
            # Prediction results at the top
            output_text = gr.Markdown("### Waiting for analysis...")
            gr.Markdown("---")
            # Visualization results
            gr.Markdown("### 🔬 Sensitivity Analysis Visualizations")
            with gr.Row():
                output_original = gr.Image(label="Original Image", type="pil")
                output_sensitivity = gr.Image(label="Sensitivity Map", type="pil")
            with gr.Row():
                output_overlay = gr.Image(label="Overlay Visualization", type="pil")

    # Event handlers (registered inside the Blocks context)
    if example_images:
        example_dropdown.change(
            fn=load_example_image,
            inputs=[example_dropdown],
            outputs=[input_image],
        )
    predict_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image, patch_size_slider, stride_slider],
        outputs=[output_original, output_sensitivity, output_overlay, output_text],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch(share=False)