processoptimisationsystem committed on
Commit
69830d3
·
verified ·
1 Parent(s): 090defa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image as Img
6
+ import numpy as np
7
+ import cv2
8
+ import matplotlib.pyplot as plt
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image
12
+ from lime.lime_image import LimeImageExplainer
13
+ from skimage.segmentation import mark_boundaries
14
+ import shap
15
+ from shap import GradientExplainer
16
+ import gradio as gr
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ num_classes = 4
20
+ image_size = (224, 224)
21
+
22
+ # Define CNN Model
23
+ class MyModel(nn.Module):
24
+ def __init__(self, num_classes=4):
25
+ super(MyModel, self).__init__()
26
+ self.features = nn.Sequential(
27
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
28
+ nn.BatchNorm2d(64),
29
+ nn.ReLU(inplace=True),
30
+ nn.MaxPool2d(kernel_size=2, stride=2),
31
+
32
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
33
+ nn.BatchNorm2d(128),
34
+ nn.ReLU(inplace=True),
35
+ nn.MaxPool2d(kernel_size=2, stride=2),
36
+
37
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
38
+ nn.BatchNorm2d(128),
39
+ nn.ReLU(inplace=True),
40
+ nn.MaxPool2d(kernel_size=2, stride=2),
41
+
42
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
43
+ nn.BatchNorm2d(256),
44
+ nn.ReLU(inplace=True),
45
+ nn.MaxPool2d(kernel_size=2, stride=2),
46
+
47
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
48
+ nn.BatchNorm2d(256),
49
+ nn.ReLU(inplace=True),
50
+ nn.MaxPool2d(kernel_size=2, stride=2),
51
+
52
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
53
+ nn.BatchNorm2d(512),
54
+ nn.ReLU(inplace=True),
55
+ nn.MaxPool2d(kernel_size=2, stride=2),
56
+ )
57
+ self.classifier = nn.Sequential(
58
+ nn.Flatten(),
59
+ nn.Linear(512 * 3 * 3, 1024),
60
+ nn.ReLU(inplace=True),
61
+ nn.Dropout(0.25),
62
+
63
+ nn.Linear(1024, 512),
64
+ nn.ReLU(inplace=True),
65
+ nn.Dropout(0.25),
66
+
67
+ nn.Linear(512, num_classes)
68
+ )
69
+ def forward(self, x):
70
+ x = self.features(x)
71
+ x = self.classifier(x)
72
+ return x
73
+
74
+ # Load model
75
+ model = MyModel(num_classes=num_classes).to(device)
76
+ model.load_state_dict(torch.load("brainCNNpytorch_model", map_location=torch.device('cpu')))
77
+ model.eval()
78
+
79
+ label_dict = {0: "Meningioma", 1: "Glioma", 2: "No Tumor", 3: "Pituitary"}
80
+
81
+ def preprocess_image(image):
82
+ transform = transforms.Compose([
83
+ transforms.Resize((224, 224)),
84
+ transforms.ToTensor(),
85
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
86
+ ])
87
+ return transform(image).unsqueeze(0).to(device)
88
+
89
+ def visualize_grad_cam(image, model, target_layer, label):
90
+ img_np = np.array(image) / 255.0
91
+ img_np = cv2.resize(img_np, (224, 224))
92
+ img_tensor = preprocess_image(image)
93
+ with torch.no_grad():
94
+ output = model(img_tensor)
95
+ _, target_index = torch.max(output, 1)
96
+ cam = GradCAM(model=model, target_layers=[target_layer])
97
+ grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(target_index.item())])[0]
98
+ grayscale_cam_resized = cv2.resize(grayscale_cam, (224, 224))
99
+ visualization = show_cam_on_image(img_np, grayscale_cam_resized, use_rgb=True)
100
+ return visualization
101
+
102
+ def model_predict(images):
103
+ preprocessed_images = [preprocess_image(Img.fromarray(img)) for img in images]
104
+ images_tensor = torch.cat(preprocessed_images).to(device)
105
+ with torch.no_grad():
106
+ logits = model(images_tensor)
107
+ probabilities = F.softmax(logits, dim=1)
108
+ return probabilities.cpu().numpy()
109
+
110
+ def visualize_lime(image):
111
+ explainer = LimeImageExplainer()
112
+ original_image = np.array(image)
113
+ explanation = explainer.explain_instance(original_image, model_predict, top_labels=3, hide_color=0, num_samples=100)
114
+ top_label = explanation.top_labels[0]
115
+ temp, mask = explanation.get_image_and_mask(label=top_label, positive_only=True, num_features=10, hide_rest=False)
116
+ return mark_boundaries(temp / 255.0, mask)
117
+
118
+ def visualize_shap(image):
119
+ img_tensor = preprocess_image(image).to(device)
120
+ if img_tensor.shape[1] == 1:
121
+ img_tensor = img_tensor.expand(-1, 3, -1, -1)
122
+ background = torch.cat([img_tensor] * 10, dim=0)
123
+ explainer = shap.GradientExplainer(model, background)
124
+ shap_values = explainer.shap_values(img_tensor)
125
+ img_numpy = img_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
126
+ shap_values = np.array(shap_values[0]).squeeze()
127
+ shap_values = shap_values / np.abs(shap_values).max() if np.abs(shap_values).max() != 0 else shap_values
128
+ shap_values = np.transpose(shap_values, (1, 2, 0))
129
+ fig, ax = plt.subplots(figsize=(5, 5))
130
+ ax.imshow(img_numpy)
131
+ ax.imshow(shap_values, cmap='jet', alpha=0.5)
132
+ ax.axis('off')
133
+ plt.tight_layout()
134
+ return fig
135
+
136
+ def classify_and_visualize(image):
137
+ image = Img.fromarray(image).convert("RGB")
138
+ image_tensor = preprocess_image(image)
139
+ with torch.no_grad():
140
+ output = model(image_tensor)
141
+ _, predicted = torch.max(output, 1)
142
+ label = label_dict[predicted.item()]
143
+ # Grad-CAM
144
+ target_layer = model.features[16] # Last Conv layer
145
+ grad_cam_img = visualize_grad_cam(image, model, target_layer, label)
146
+ # LIME
147
+ lime_img = visualize_lime(image)
148
+ # SHAP
149
+ shap_fig = visualize_shap(image)
150
+
151
+ return label, grad_cam_img, lime_img, shap_fig
152
+
153
+ # Create Gradio interface
154
+ title = "Brain Tumor Classification with Grad-CAM, LIME, and SHAP"
155
+
156
+ inputs = gr.Image(type="numpy", label="Upload an MRI Image")
157
+ outputs = [
158
+ gr.Textbox(label="Prediction"),
159
+ gr.Image(type="numpy", label="Grad-CAM"),
160
+ gr.Image(type="numpy", label="LIME Explanation"),
161
+ gr.Plot(label="SHAP Explanation")
162
+ ]
163
+
164
+ iface = gr.Interface(fn=classify_and_visualize, inputs=inputs, outputs=outputs, title=title)
165
+ iface.launch()