Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import keras | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from PIL import Image | |
| import io | |
| import glob | |
| import os | |
| import pickle | |
| from tqdm import tqdm | |
| import numpy as np | |
# Load the pickled model once, at import time.
# NOTE(review): unpickling can execute arbitrary code, so 'model.pkl' must
# only ever come from a trusted source.
with open('model.pkl', mode='rb') as model_file:
    model = pickle.load(model_file)

# Human-readable labels, looked up by the predicted class index.
class_names = [
    'glioma',
    'meningioma',
    'notumor',
    'pituitary',
]
def create_custom_colormap():
    """Build the two-stop colormap used for the sensitivity overlay.

    The colormap interpolates linearly between two RGBA endpoints:
    fully transparent black ``(0, 0, 0, 0)`` and fully opaque red
    ``(1, 0, 0, 1)``.

    Returns:
        A ``LinearSegmentedColormap`` named ``'custom'``.
    """
    # RGBA stops: the low end is transparent, the high end is solid red.
    colors = [(0, 0, 0, 0), (1, 0, 0, 1)]
    return LinearSegmentedColormap.from_list('custom', colors)
def _window_starts(length, patch_size, stride):
    """Return window offsets along one axis, stopping once a window reaches the end."""
    starts = []
    for start in range(0, length, stride):
        starts.append(start)
        if start + patch_size >= length:
            # This window already touches (or crosses) the border; any
            # further window would only re-cover a subset of it.
            break
    return starts


def occlusion_sensitivity(model, img_array, class_index, patch_size=25, stride=5, occlusion_value=0, progress=gr.Progress()):
    """Compute an occlusion-sensitivity map for one class.

    A square patch is slid over the image; for every position the patch is
    overwritten with ``occlusion_value``, the model is re-run, and the drop
    in the probability of ``class_index`` is added to every pixel the patch
    covered. Contributions of overlapping patches are summed, not averaged.

    The last window along each axis is clipped to the image border, so the
    trailing rows/columns are covered even when
    ``(size - patch_size)`` is not a multiple of ``stride``, and a patch
    larger than the image still yields one (whole-image) window.

    Args:
        model: Object exposing ``predict(batch, verbose=0)`` that returns one
            row of class scores per batch item.
        img_array: Image indexed as ``[row, col, channel]``.
        class_index: Index of the class whose score is tracked.
        patch_size: Side length of the occluding square, in pixels.
        stride: Step between consecutive patch positions, in pixels.
        occlusion_value: Value written into the occluded region.
        progress: Callable invoked as ``progress(fraction, desc=...)`` after
            every window.

    Returns:
        ``(sensitivity_map, original_prob)`` where ``sensitivity_map`` has the
        image's height and width and ``original_prob`` is the unoccluded
        score for ``class_index``.

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive.
    """
    # UI widgets may hand over floats; ranges and slices need real ints.
    patch_size = int(patch_size)
    stride = int(stride)
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive")

    height, width = img_array.shape[0], img_array.shape[1]
    sensitivity_map = np.zeros((height, width))

    original_pred = model.predict(np.expand_dims(img_array, axis=0), verbose=0)[0]
    original_prob = original_pred[class_index]

    h_starts = _window_starts(height, patch_size, stride)
    w_starts = _window_starts(width, patch_size, stride)
    total_steps = len(h_starts) * len(w_starts)

    current_step = 0
    for h_start in h_starts:
        h_end = min(h_start + patch_size, height)
        for w_start in w_starts:
            w_end = min(w_start + patch_size, width)

            # Occlude a copy so the caller's array is never modified.
            occluded_img = img_array.copy()
            occluded_img[h_start:h_end, w_start:w_end, :] = occlusion_value

            pred = model.predict(np.expand_dims(occluded_img, axis=0), verbose=0)[0]
            prob = pred[class_index]

            # Positive value: hiding this region lowered the class score.
            sensitivity = original_prob - prob
            sensitivity_map[h_start:h_end, w_start:w_end] += sensitivity

            current_step += 1
            progress(current_step / total_steps, desc=f"Analyzing sensitivity: {current_step}/{total_steps}")

    return sensitivity_map, original_prob
def _render_to_image(draw):
    """Create a 6x6 inch figure, let ``draw(fig, ax)`` fill it, and return it as a PIL image.

    The figure is addressed through its own handle (not pyplot's implicit
    "current figure") and is always closed, even if drawing or saving raises.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    buf = io.BytesIO()
    try:
        draw(fig, ax)
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def create_sensitivity_visualizations(image, sensitivity_map, predicted_class, confidence):
    """Render the three result panels as PIL images.

    Args:
        image: The image shown in the "original" and "overlay" panels.
        sensitivity_map: 2-D array of sensitivity values. The overlay hides
            every value below 0.15, so the map is presumably expected to be
            normalised beforehand -- confirm against the caller.
        predicted_class: Index into ``class_names`` used for the title.
        confidence: Score of the predicted class, formatted as a percentage.

    Returns:
        ``(original_img, sensitivity_img, overlay_img)``, each a PIL image
        decoded from a 150-dpi PNG rendering.
    """

    def draw_original(fig, ax):
        ax.imshow(image, cmap='gray')
        ax.set_title(f'Original Image\n{class_names[predicted_class]} ({confidence:.1%})', fontsize=12, fontweight='bold')
        ax.axis('off')

    def draw_sensitivity(fig, ax):
        heatmap = ax.imshow(sensitivity_map, cmap='jet')
        ax.set_title('Sensitivity Map', fontsize=12, fontweight='bold')
        ax.axis('off')
        fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)

    def draw_overlay(fig, ax):
        ax.imshow(image, cmap='gray')
        custom_cmap = create_custom_colormap()
        # Mask out weak responses so only the strongest regions are tinted.
        masked_sensitivity = np.ma.masked_where(sensitivity_map < 0.15, sensitivity_map)
        tint = ax.imshow(masked_sensitivity, cmap=custom_cmap, alpha=0.6)
        ax.set_title('Overlay (Red = High Importance)', fontsize=12, fontweight='bold')
        ax.axis('off')
        fig.colorbar(tint, ax=ax, fraction=0.046, pad=0.04)

    original_img = _render_to_image(draw_original)
    sensitivity_img = _render_to_image(draw_sensitivity)
    overlay_img = _render_to_image(draw_overlay)
    return original_img, sensitivity_img, overlay_img
def predict_and_visualize(image, patch_size=25, stride=5, progress=gr.Progress()):
    """Classify an image and build its occlusion-sensitivity panels.

    The input is converted to a 256x256 single-channel image scaled to
    [0, 1], classified with the module-level ``model``, and then analysed
    with ``occlusion_sensitivity`` for the predicted class. The resulting
    map is min-max normalised before being rendered.

    Args:
        image: A PIL image, an array accepted by ``Image.fromarray``, or
            ``None``.
        patch_size: Occlusion patch size forwarded to the analysis.
        stride: Occlusion stride forwarded to the analysis.
        progress: Progress callback forwarded to the analysis and called
            once more with ``1.0`` at the end.

    Returns:
        ``(original_img, sensitivity_img, overlay_img, pred_text)``; the
        three images are ``None`` when no image was supplied.
    """
    # NOTE(review): the non-ASCII characters in the strings below look like
    # mis-decoded UTF-8 emoji; they are kept verbatim -- confirm the intended
    # glyphs against the original file.
    if image is None:
        return None, None, None, "β οΈ Please select or upload an image"

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    pil_image = pil_image.convert('L')
    pil_image = pil_image.resize((256, 256))
    img_array = img_to_array(pil_image) / 255.0

    # Classify the unmodified image first.
    batch = np.expand_dims(img_array, axis=0)
    predictions = model.predict(batch, verbose=0)[0]
    predicted_class = np.argmax(predictions)
    confidence = predictions[predicted_class]

    report_parts = [
        "## π― Prediction Results\n\n",
        f"### **Predicted Class:** {class_names[predicted_class].upper()}\n",
        f"### **Confidence:** {confidence:.2%}\n\n",
        "---\n\n",
        "### π Class Probabilities:\n\n",
    ]
    # One line per class, most probable first, with a text bar of up to 20 cells.
    for idx in np.argsort(predictions)[::-1]:
        bar = "β" * int(predictions[idx] * 20)
        report_parts.append(f"**{class_names[idx].capitalize()}:** {predictions[idx]:.2%} {bar}\n\n")
    pred_text = "".join(report_parts)

    sensitivity_map, _ = occlusion_sensitivity(
        model,
        img_array,
        predicted_class,
        patch_size=patch_size,
        stride=stride,
        progress=progress,
    )

    # Min-max normalise; eps keeps the division defined for a flat map.
    eps = 1e-9
    lowest = np.min(sensitivity_map)
    highest = np.max(sensitivity_map)
    sensitivity_map = (sensitivity_map - lowest) / (highest - lowest + eps)

    panels = create_sensitivity_visualizations(
        pil_image, sensitivity_map, predicted_class, confidence
    )
    original_img, sensitivity_img, overlay_img = panels

    progress(1.0, desc="Complete!")
    return original_img, sensitivity_img, overlay_img, pred_text
def get_example_images(examples_folder='examples/'):
    """Return the sorted paths of the example images in ``examples_folder``.

    A file qualifies when its extension is ``.jpg``, ``.jpeg`` or ``.png``
    in any letter case. Matching is done on the lower-cased name, so
    mixed-case extensions (``.Jpeg``) are found and no file can be listed
    twice on platforms where pattern matching ignores case. Hidden
    entries (leading dot) and non-files are skipped.

    Args:
        examples_folder: Directory to scan (not recursive).

    Returns:
        A sorted list of paths, each ``os.path.join(examples_folder, name)``;
        an empty list when the directory does not exist.
    """
    if not os.path.isdir(examples_folder):
        return []

    image_suffixes = ('.jpg', '.jpeg', '.png')
    image_paths = []
    for name in os.listdir(examples_folder):
        # Mirror glob's '*' semantics: dot-files are never matched.
        if name.startswith('.') or not name.lower().endswith(image_suffixes):
            continue
        path = os.path.join(examples_folder, name)
        if os.path.isfile(path):
            image_paths.append(path)
    return sorted(image_paths)
def load_example_image(image_path):
    """Open the example at ``image_path``; return ``None`` when no path is given."""
    if not image_path:
        # Nothing selected (e.g. the dropdown was cleared).
        return None
    return Image.open(image_path)
# Build the Gradio interface: inputs on the left, results on the right.
# NOTE(review): several UI strings below contain characters that look like
# mis-decoded UTF-8 emoji/bullets; they are reproduced verbatim -- restore the
# intended glyphs from the original file.
with gr.Blocks(title="Brain MRI Tumor Classifier", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
    # π§ Brain MRI Tumor Classifier with Occlusion Sensitivity
    This application classifies brain MRI images into four categories and visualizes which regions
    are most important for the classification decision.
    **Classes:** Glioma β’ Meningioma β’ No Tumor β’ Pituitary
    """)

    with gr.Row():
        # Left column - input controls
        with gr.Column(scale=1):
            gr.Markdown("### π Select or Upload Image")

            # Example selector: active when examples exist, disabled otherwise.
            example_images = get_example_images('examples/')
            if example_images:
                dropdown_options = dict(
                    choices=example_images,
                    label="Select from Examples folder",
                    value=None,
                    interactive=True,
                )
            else:
                dropdown_options = dict(
                    choices=[],
                    label="No examples found in 'Examples' folder",
                    interactive=False,
                )
            example_dropdown = gr.Dropdown(**dropdown_options)

            # Alternative input: user upload
            gr.Markdown("**OR**")
            input_image = gr.Image(type="pil", label="Upload Your Own MRI Image")

            gr.Markdown("### βοΈ Sensitivity Analysis Settings")
            patch_size_slider = gr.Slider(
                minimum=10,
                maximum=50,
                value=50,
                step=5,
                label="Patch Size",
                info="Larger = faster but less detailed",
            )
            stride_slider = gr.Slider(
                minimum=5,
                maximum=20,
                value=20,
                step=5,
                label="Stride",
                info="Larger = faster but less detailed",
            )

            predict_btn = gr.Button("π Analyze Image", variant="primary", size="lg")

            gr.Markdown("""
            ---
            **βΉοΈ Tip:** Select an image from the dropdown or upload your own,
            then click 'Analyze Image' to see predictions and sensitivity maps.
            ---
            **π Note:** The default values for patch size and stride have been set to the maximum due to computational constraints.
            Though the results will be enough for a quick demo.
            """)

        # Right column - results
        with gr.Column(scale=2):
            # Textual prediction summary comes first
            output_text = gr.Markdown("### Waiting for analysis...")
            gr.Markdown("---")

            # Rendered panels
            gr.Markdown("### π¬ Sensitivity Analysis Visualizations")
            with gr.Row():
                output_original = gr.Image(label="Original Image", type="pil")
                output_sensitivity = gr.Image(label="Sensitivity Map", type="pil")
            with gr.Row():
                output_overlay = gr.Image(label="Overlay Visualization", type="pil")

    # Event wiring
    if example_images:
        # Picking an example loads it into the image input.
        example_dropdown.change(
            fn=load_example_image,
            inputs=[example_dropdown],
            outputs=[input_image],
        )

    predict_btn.click(
        fn=predict_and_visualize,
        inputs=[input_image, patch_size_slider, stride_slider],
        outputs=[output_original, output_sensitivity, output_overlay, output_text],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch(share=False)