File size: 6,031 Bytes
69830d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image as Img
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from lime.lime_image import LimeImageExplainer
from skimage.segmentation import mark_boundaries
import shap
from shap import GradientExplainer
import gradio as gr
# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
num_classes = 4  # must agree with label_dict and with the checkpoint's final Linear layer
image_size = (224, 224)  # resolution the classifier head's 512 * 3 * 3 fan-in is sized for
# Define CNN Model
class MyModel(nn.Module):
    """Six-stage convolutional classifier followed by a three-layer MLP head.

    Every stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2).
    The head's first Linear layer expects 512 * 3 * 3 features, which is what a
    224x224 input leaves after six 2x2 poolings (224 / 2**6 = 3.5, floored to 3).

    The submodule indices inside ``features`` and ``classifier`` appear in the
    state_dict keys, so the layer order must stay exactly as built here.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # (in_channels, out_channels) for the six convolutional stages, in order.
        channel_plan = [(3, 64), (64, 128), (128, 128), (128, 256), (256, 256), (256, 512)]
        stages = []
        for in_ch, out_ch in channel_plan:
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.features = nn.Sequential(*stages)

        # Flatten, then two hidden Linear -> ReLU -> Dropout(0.25) layers, then the logits layer.
        head = [nn.Flatten()]
        for fan_in, fan_out in ((512 * 3 * 3, 1024), (1024, 512)):
            head += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True), nn.Dropout(0.25)]
        head.append(nn.Linear(512, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map an image batch to unnormalised class scores (one column per class)."""
        return self.classifier(self.features(x))
# Load model
model = MyModel(num_classes=num_classes).to(device)
# The checkpoint is a state_dict: it is read onto the CPU first, and
# load_state_dict copies the tensors into the (possibly CUDA) parameters.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if the
# installed torch supports it (>= 1.13).
model.load_state_dict(
    torch.load("brainCNNpytorch_model", map_location=torch.device('cpu'))
)
model.eval()  # inference mode: BatchNorm uses running statistics, Dropout is disabled
# Class index -> human-readable tumour category.
label_dict = dict(enumerate(["Meningioma", "Glioma", "No Tumor", "Pituitary"]))
# Built once at import time rather than on every call: model_predict invokes
# preprocess_image once per LIME perturbation sample.
_PREPROCESS = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Per-channel mean/std (the commonly used ImageNet statistics).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert a PIL image into a normalised ``(1, 3, 224, 224)`` tensor on ``device``.

    Args:
        image: PIL image. Non-RGB modes (e.g. grayscale "L" or "RGBA") are
            converted to RGB first; otherwise ToTensor would yield 1 or 4 channels
            and the 3-channel Normalize step would fail.

    Returns:
        Float tensor with a leading batch dimension of 1, moved to ``device``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0).to(device)
def visualize_grad_cam(image, model, target_layer, label):
    """Overlay a Grad-CAM heatmap for the model's top-scoring class on *image*.

    Args:
        image: PIL image to explain.
        model: classifier whose gradients are inspected.
        target_layer: module inside *model* whose activations Grad-CAM weights.
        label: unused. It is kept for call-site compatibility; the class to
            explain is recomputed from the model's own prediction.

    Returns:
        RGB overlay array (as produced by ``show_cam_on_image(..., use_rgb=True)``)
        resized to ``image_size``.
    """
    # show_cam_on_image wants a float RGB image in [0, 1]. convert("RGB") guarantees
    # three channels, so the image always matches the colour heatmap.
    rgb = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    rgb = cv2.resize(rgb, image_size)  # cv2 dsize is (width, height); image_size is square
    img_tensor = preprocess_image(image)

    with torch.no_grad():
        target_index = int(model(img_tensor).argmax(dim=1).item())

    # The context manager removes GradCAM's hooks from target_layer as soon as the
    # CAM is computed, instead of whenever the GradCAM object is garbage-collected.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(
            input_tensor=img_tensor,
            targets=[ClassifierOutputTarget(target_index)],
        )[0]

    grayscale_cam = cv2.resize(grayscale_cam, image_size)
    return show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
def model_predict(images):
    """Classifier function for LIME: map a batch of image arrays to class probabilities.

    Args:
        images: iterable of arrays accepted by ``PIL.Image.fromarray``. LIME
            passes perturbed copies of the image being explained.

    Returns:
        NumPy array with one row per input image, holding softmax probabilities
        over the model's output classes.
    """
    batch = torch.cat([preprocess_image(Img.fromarray(arr)) for arr in images]).to(device)
    with torch.no_grad():
        probs = model(batch).softmax(dim=1)
    return probs.cpu().numpy()
def visualize_lime(image):
    """Outline the superpixels LIME finds most supportive of the top-ranked class.

    Runs LimeImageExplainer with 100 perturbation samples, using model_predict as
    the classifier and hiding removed segments with 0. It then marks the boundaries
    of the 10 strongest positively contributing segments for the highest-ranked
    label, drawn on top of the unmodified image.

    Returns:
        Image array from skimage's ``mark_boundaries``. Pixel values are scaled by
        1/255, so they lie in [0, 1] assuming *image* is 8-bit.
    """
    pixels = np.array(image)
    explanation = LimeImageExplainer().explain_instance(
        pixels,
        model_predict,
        top_labels=3,
        hide_color=0,
        num_samples=100,
    )
    best_label = explanation.top_labels[0]
    overlay, segment_mask = explanation.get_image_and_mask(
        label=best_label,
        positive_only=True,
        num_features=10,
        hide_rest=False,
    )
    return mark_boundaries(overlay / 255.0, segment_mask)
def visualize_shap(image):
    """Plot SHAP (GradientExplainer) attributions for the model's top class over *image*.

    Args:
        image: PIL image to explain.

    Returns:
        A matplotlib Figure showing the resized input image. A signed attribution
        heatmap is drawn on top: red pixels push towards the predicted class and
        blue pixels push away from it.
    """
    # Local import: a bare Figure is not registered with pyplot, so the figure
    # returned on every request can be garbage-collected instead of accumulating.
    from matplotlib.figure import Figure

    rgb_image = image.convert("RGB")
    img_tensor = preprocess_image(rgb_image)

    # Baseline = an all-black image run through the same preprocessing. Using
    # copies of the input as background makes (input - baseline) zero everywhere,
    # which forces every SHAP value to 0.
    background = preprocess_image(Img.new("RGB", image_size))
    explainer = shap.GradientExplainer(model, background)

    # ranked_outputs=1 explains the highest-scoring class instead of always class 0.
    shap_values, _ = explainer.shap_values(img_tensor, ranked_outputs=1)
    # Drop the singleton batch/output axes. This copes with both the list-of-arrays
    # format and the stacked-array format some shap releases return, giving (3, H, W).
    attributions = np.asarray(shap_values, dtype=np.float32).squeeze()
    heatmap = attributions.sum(axis=0)  # collapse colour channels -> (H, W)
    peak = np.abs(heatmap).max()
    if peak > 0:
        heatmap = heatmap / peak

    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    # Show the un-normalised resized image; the normalised tensor would be clipped.
    ax.imshow(np.asarray(rgb_image.resize(image_size)))
    # A 2-D map so the colormap actually applies. The symmetric limits keep 0 neutral.
    ax.imshow(heatmap, cmap="bwr", vmin=-1, vmax=1, alpha=0.5)
    ax.axis("off")
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    return fig
def classify_and_visualize(image):
    """Gradio callback: classify an uploaded MRI image and explain the prediction.

    Args:
        image: uploaded image as a NumPy array (``gr.Image(type="numpy")``), or
            None when the user submits without uploading anything.

    Returns:
        Tuple of (predicted label text, Grad-CAM overlay, LIME boundary image,
        SHAP matplotlib figure), in the order of the Interface outputs.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")
    image = Img.fromarray(image).convert("RGB")
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(image_tensor)
    label = label_dict[int(output.argmax(dim=1).item())]

    # Grad-CAM targets the last convolutional layer. Look it up rather than
    # hard-coding an index: model.features[16] is only the fifth of six Conv2d layers.
    target_layer = [m for m in model.features if isinstance(m, nn.Conv2d)][-1]
    grad_cam_img = visualize_grad_cam(image, model, target_layer, label)
    # LIME
    lime_img = visualize_lime(image)
    # SHAP
    shap_fig = visualize_shap(image)
    return label, grad_cam_img, lime_img, shap_fig
# Create Gradio interface
title = "Brain Tumor Classification with Grad-CAM, LIME, and SHAP"
inputs = gr.Image(label="Upload an MRI Image", type="numpy")
# One component per element of the tuple returned by classify_and_visualize, in order.
outputs = [
    gr.Textbox(label="Prediction"),
    *(gr.Image(label=caption, type="numpy") for caption in ("Grad-CAM", "LIME Explanation")),
    gr.Plot(label="SHAP Explanation"),
]
iface = gr.Interface(
    fn=classify_and_visualize,
    inputs=inputs,
    outputs=outputs,
    title=title,
)
# Starts the web server as soon as this module is executed.
iface.launch()
|