"""Gradio demo: brain MRI tumor classification with occlusion-sensitivity maps.

The app loads a pickled classifier, predicts one of four classes for an
uploaded (or example) MRI image, and visualizes which image regions the
prediction depends on by occluding patches and measuring the drop in the
predicted-class probability.
"""

import glob
import io
import os
import pickle

import gradio as gr
import keras  # noqa: F401  (kept from the original file; presumably needed to unpickle the model - TODO confirm)
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tqdm import tqdm  # noqa: F401  (kept from the original file; unused in this module)

# Figures are only ever rendered to in-memory PNGs from Gradio worker threads,
# so select the non-interactive backend instead of whatever GUI backend the
# host machine would default to.
matplotlib.use("Agg")

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# Only ship a model.pkl that comes from a trusted source.
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)

class_names = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Size (width, height) every input image is resized to before prediction.
INPUT_SIZE = (256, 256)
# Normalized sensitivity values below this are hidden in the overlay plot.
OVERLAY_THRESHOLD = 0.15
# Number of occluded images sent to model.predict in one call.
OCCLUSION_BATCH_SIZE = 32

HEADER_MARKDOWN = """
# 🧠 Brain MRI Tumor Classifier with Occlusion Sensitivity

This application classifies brain MRI images into four categories and visualizes which regions are most important for the classification decision.

**Classes:** Glioma • Meningioma • No Tumor • Pituitary
"""

TIPS_MARKDOWN = """
---

**ℹ️ Tip:** Select an image from the dropdown or upload your own, then click 'Analyze Image' to see predictions and sensitivity maps.

---

**📝 Note:** The default values for patch size and stride have been set to the maximum due to computational constraints. Though the results will be enough for a quick demo.
"""


def create_custom_colormap():
    """Return a colormap fading from fully transparent black to opaque red."""
    colors = [(0, 0, 0, 0), (1, 0, 0, 1)]
    return LinearSegmentedColormap.from_list('custom', colors)


def occlusion_sensitivity(model, img_array, class_index, patch_size=25, stride=5,
                          occlusion_value=0, progress=gr.Progress()):
    """Measure how much occluding each image patch lowers the class probability.

    A ``patch_size`` x ``patch_size`` window is moved over the image in steps
    of ``stride``. For every position the window is filled with
    ``occlusion_value``, the model is re-run, and the drop in probability for
    ``class_index`` is added to every pixel under the window. Contributions
    of overlapping windows accumulate (they are summed, not averaged).

    Window positions are ``(size - patch_size) // stride + 1`` per axis, so
    pixels near the right/bottom edge are never occluded when
    ``size - patch_size`` is not a multiple of ``stride``. If the patch is
    larger than the image, no position exists and the map stays all zeros.

    Args:
        model: Object with a Keras-style ``predict(batch, verbose=0)`` method
            returning one probability row per batch item.
        img_array: Image array indexed as ``[row, column, channel]``.
        class_index: Index of the class whose probability is tracked.
        patch_size: Side length of the square occlusion window, in pixels.
        stride: Step between consecutive window positions, in pixels.
        occlusion_value: Value written into the occluded window.
        progress: Callable invoked as ``progress(fraction, desc=...)`` after
            each processed batch.

    Returns:
        Tuple ``(sensitivity_map, original_prob)`` where ``sensitivity_map``
        has the image's height and width and ``original_prob`` is the
        un-occluded probability of ``class_index``.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive.
    """
    # Slider widgets may deliver floats; range() and slicing need ints.
    patch_size = int(patch_size)
    stride = int(stride)
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive integers")

    height, width = img_array.shape[0], img_array.shape[1]
    sensitivity_map = np.zeros((height, width))

    original_pred = model.predict(np.expand_dims(img_array, axis=0), verbose=0)[0]
    original_prob = original_pred[class_index]

    n_steps_h = max((height - patch_size) // stride + 1, 0)
    n_steps_w = max((width - patch_size) // stride + 1, 0)

    # (top, bottom, left, right) bounds of every occlusion window, row-major.
    windows = [
        (
            h * stride,
            min(h * stride + patch_size, height),
            w * stride,
            min(w * stride + patch_size, width),
        )
        for h in range(n_steps_h)
        for w in range(n_steps_w)
    ]
    total_steps = len(windows)

    # Predict in batches: one model.predict call per patch is dominated by
    # per-call overhead and makes fine-grained settings impractically slow.
    for start in range(0, total_steps, OCCLUSION_BATCH_SIZE):
        chunk = windows[start:start + OCCLUSION_BATCH_SIZE]
        batch = np.repeat(img_array[np.newaxis, ...], len(chunk), axis=0)
        for i, (top, bottom, left, right) in enumerate(chunk):
            batch[i, top:bottom, left:right, :] = occlusion_value

        probs = model.predict(batch, verbose=0)[:, class_index]

        for (top, bottom, left, right), prob in zip(chunk, probs):
            sensitivity_map[top:bottom, left:right] += original_prob - prob

        current_step = start + len(chunk)
        progress(current_step / total_steps,
                 desc=f"Analyzing sensitivity: {current_step}/{total_steps}")

    return sensitivity_map, original_prob


def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return a PIL image."""
    try:
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure even if rendering fails, so figures do
        # not pile up in pyplot's global registry across requests.
        plt.close(fig)
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode now so the result no longer depends on the buffer
    return rendered


def create_sensitivity_visualizations(image, sensitivity_map, predicted_class, confidence):
    """Build the three result images shown in the UI.

    Args:
        image: The (grayscale) input image as displayed by ``imshow``.
        sensitivity_map: 2-D sensitivity values; the overlay hides values
            below ``OVERLAY_THRESHOLD``, so callers are expected to pass a
            map normalized to the 0-1 range.
        predicted_class: Index into ``class_names`` used for the title.
        confidence: Probability of the predicted class, shown as a percentage.

    Returns:
        Tuple ``(original_img, sensitivity_img, overlay_img)`` of PIL images.
    """
    # Original image
    fig1, ax1 = plt.subplots(figsize=(6, 6))
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Original Image\n{class_names[predicted_class]} ({confidence:.1%})',
                  fontsize=12, fontweight='bold')
    ax1.axis('off')
    original_img = _figure_to_image(fig1)

    # Sensitivity map
    fig2, ax2 = plt.subplots(figsize=(6, 6))
    im2 = ax2.imshow(sensitivity_map, cmap='jet')
    ax2.set_title('Sensitivity Map', fontsize=12, fontweight='bold')
    ax2.axis('off')
    fig2.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
    sensitivity_img = _figure_to_image(fig2)

    # Overlay: only regions above the threshold are drawn on top of the image.
    fig3, ax3 = plt.subplots(figsize=(6, 6))
    ax3.imshow(image, cmap='gray')
    masked_sensitivity = np.ma.masked_where(sensitivity_map < OVERLAY_THRESHOLD, sensitivity_map)
    im3 = ax3.imshow(masked_sensitivity, cmap=create_custom_colormap(), alpha=0.6)
    ax3.set_title('Overlay (Red = High Importance)', fontsize=12, fontweight='bold')
    ax3.axis('off')
    fig3.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
    overlay_img = _figure_to_image(fig3)

    return original_img, sensitivity_img, overlay_img


def predict_and_visualize(image, patch_size=25, stride=5, progress=gr.Progress()):
    """Classify an MRI image and produce its occlusion-sensitivity visuals.

    Args:
        image: PIL image or array from the Gradio image input, or ``None``.
        patch_size: Occlusion window size forwarded to ``occlusion_sensitivity``.
        stride: Occlusion step forwarded to ``occlusion_sensitivity``.
        progress: Gradio progress tracker (must stay a default argument so
            Gradio injects it).

    Returns:
        Tuple ``(original_img, sensitivity_img, overlay_img, markdown_text)``.
        When ``image`` is ``None`` the three images are ``None`` and the text
        asks the user to provide an image.
    """
    if image is None:
        return None, None, None, "⚠️ Please select or upload an image"

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Preprocess: grayscale, fixed size, pixel values scaled to [0, 1].
    image = image.convert('L')
    image = image.resize(INPUT_SIZE)
    img_array = img_to_array(image) / 255.0

    # Initial prediction
    predictions = model.predict(np.expand_dims(img_array, axis=0), verbose=0)[0]
    predicted_class = int(np.argmax(predictions))
    confidence = predictions[predicted_class]

    text_parts = [
        "## 🎯 Prediction Results\n\n",
        f"### **Predicted Class:** {class_names[predicted_class].upper()}\n",
        f"### **Confidence:** {confidence:.2%}\n\n",
        "---\n\n",
        "### 📊 Class Probabilities:\n\n",
    ]
    # Classes listed from most to least probable, each with a text bar.
    for i in np.argsort(predictions)[::-1]:
        bar = "█" * int(predictions[i] * 20)
        text_parts.append(f"**{class_names[i].capitalize()}:** {predictions[i]:.2%} {bar}\n\n")
    pred_text = "".join(text_parts)

    # Compute sensitivity map with progress tracking
    sensitivity_map, _ = occlusion_sensitivity(
        model, img_array, predicted_class,
        patch_size=patch_size, stride=stride, progress=progress
    )

    # Min-max normalize to [0, 1]; eps avoids division by zero on a flat map.
    eps = 1e-9
    map_min = np.min(sensitivity_map)
    map_max = np.max(sensitivity_map)
    sensitivity_map = (sensitivity_map - map_min) / (map_max - map_min + eps)

    # Create visualizations
    original_img, sensitivity_img, overlay_img = create_sensitivity_visualizations(
        image, sensitivity_map, predicted_class, confidence
    )

    progress(1.0, desc="Complete!")
    return original_img, sensitivity_img, overlay_img, pred_text


def get_example_images(examples_folder='examples/'):
    """Return the sorted paths of the example images found in ``examples_folder``.

    Returns an empty list when the folder does not exist. Paths are collected
    in a set because on case-insensitive filesystems the lower- and upper-case
    patterns match the same files.
    """
    if not os.path.exists(examples_folder):
        return []

    image_paths = set()
    for ext in ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG'):
        image_paths.update(glob.glob(os.path.join(examples_folder, ext)))
    return sorted(image_paths)


def load_example_image(image_path):
    """Load the selected example image, or return ``None`` if nothing is selected."""
    if not image_path:
        return None
    image = Image.open(image_path)
    # Decode immediately rather than lazily so Pillow does not need to keep
    # the underlying file open after this function returns.
    image.load()
    return image


with gr.Blocks(title="Brain MRI Tumor Classifier", theme=gr.themes.Default()) as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Row():
        # Left column - Input controls
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Select or Upload Image")

            # Example image selector
            example_images = get_example_images('examples/')
            if example_images:
                example_dropdown = gr.Dropdown(
                    choices=example_images,
                    label="Select from Examples folder",
                    value=None,
                    interactive=True
                )
            else:
                example_dropdown = gr.Dropdown(
                    choices=[],
                    label="No examples found in 'Examples' folder",
                    interactive=False
                )

            # OR upload custom image
            gr.Markdown("**OR**")
            input_image = gr.Image(type="pil", label="Upload Your Own MRI Image")

            gr.Markdown("### ⚙️ Sensitivity Analysis Settings")
            patch_size_slider = gr.Slider(
                minimum=10, maximum=50, value=50, step=5,
                label="Patch Size",
                info="Larger = faster but less detailed"
            )
            stride_slider = gr.Slider(
                minimum=5, maximum=20, value=20, step=5,
                label="Stride",
                info="Larger = faster but less detailed"
            )

            predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")

            gr.Markdown(TIPS_MARKDOWN)

        # Right column - Results
        with gr.Column(scale=2):
            # Prediction results at the top
            output_text = gr.Markdown("### Waiting for analysis...")
            gr.Markdown("---")

            # Visualization results
            gr.Markdown("### 🔬 Sensitivity Analysis Visualizations")
            with gr.Row():
                output_original = gr.Image(label="Original Image", type="pil")
                output_sensitivity = gr.Image(label="Sensitivity Map", type="pil")
            with gr.Row():
                output_overlay = gr.Image(label="Overlay Visualization", type="pil")

    # Event handlers
    if example_images:
        example_dropdown.change(
            fn=load_example_image,
            inputs=[example_dropdown],
            outputs=[input_image]
        )

    predict_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image, patch_size_slider, stride_slider],
        outputs=[output_original, output_sensitivity, output_overlay, output_text],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch(share=False)