gmustafa413 commited on
Commit
8a3ec9d
·
verified ·
1 Parent(s): 085ed0d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +278 -0
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import os, shutil
4
+ import matplotlib.pyplot as plt
5
+
6
+ from PIL import Image
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ import torch
11
+ import torchvision
12
+ from torch.utils.data import DataLoader
13
+ from torchvision.datasets import ImageFolder
14
+ from torchvision.transforms import transforms
15
+ import torch.optim as optim
16
+
17
+ from torchvision.models import resnet50, ResNet50_Weights
18
+
19
+
20
+ transform = transforms.Compose([
21
+ transforms.Resize((224,224)),
22
+ transforms.ToTensor()
23
+ ])
24
+
25
+ import urllib.request
26
+ urllib.request.urlretrieve("https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937484-1629951672/carpet.tar.xz",
27
+ "carpet.tar.xz")
28
+
29
+ import tarfile
30
+
31
+ with tarfile.open('carpet.tar.xz') as f:
32
+ f.extractall('.')
33
+
34
+
35
+ class resnet_feature_extractor(torch.nn.Module):
36
+ def __init__(self):
37
+ """This class extracts the feature maps from a pretrained Resnet model."""
38
+ super(resnet_feature_extractor, self).__init__()
39
+ self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
40
+
41
+ self.model.eval()
42
+ for param in self.model.parameters():
43
+ param.requires_grad = False
44
+
45
+
46
+
47
+ # Hook to extract feature maps
48
+ def hook(module, input, output) -> None:
49
+ """This hook saves the extracted feature map on self.featured."""
50
+ self.features.append(output)
51
+
52
+ self.model.layer2[-1].register_forward_hook(hook)
53
+ self.model.layer3[-1].register_forward_hook(hook)
54
+
55
+ def forward(self, input):
56
+
57
+ self.features = []
58
+ with torch.no_grad():
59
+ _ = self.model(input)
60
+
61
+ self.avg = torch.nn.AvgPool2d(3, stride=1)
62
+ fmap_size = self.features[0].shape[-2] # Feature map sizes h, w
63
+ self.resize = torch.nn.AdaptiveAvgPool2d(fmap_size)
64
+
65
+ resized_maps = [self.resize(self.avg(fmap)) for fmap in self.features]
66
+ patch = torch.cat(resized_maps, 1) # Merge the resized feature maps
67
+ patch = patch.reshape(patch.shape[1], -1).T # Craete a column tensor
68
+
69
+ return patch
70
+
71
+
72
+
73
+ image = Image.open(r'/content/carpet/test/color/000.png')
74
+ image = transform(image).unsqueeze(0)
75
+
76
+ backbone = resnet_feature_extractor()
77
+ feature = backbone(image)
78
+
79
+ # print(backbone.features[0].shape)
80
+ # print(backbone.features[1].shape)
81
+
82
+ print(feature.shape)
83
+
84
+ # plt.imshow(image[0].permute(1,2,0))
85
+
86
+ memory_bank =[]
87
+
88
+ folder_path = Path(r'/content/carpet/train/good')
89
+
90
+ for pth in tqdm(folder_path.iterdir(),leave=False):
91
+ with torch.no_grad():
92
+ data = transform(Image.open(pth)).unsqueeze(0)
93
+ features = backbone(data)
94
+ memory_bank.append(features.cpu().detach())
95
+
96
+ memory_bank = torch.cat(memory_bank,dim=0)
97
+
98
+ y_score=[]
99
+
100
+ folder_path = Path(r'/content/carpet/train/good')
101
+
102
+ for pth in tqdm(folder_path.iterdir(),leave=False):
103
+ data = transform(Image.open(pth)).unsqueeze(0)
104
+ with torch.no_grad():
105
+ features = backbone(data)
106
+ distances = torch.cdist(features, memory_bank, p=2.0)
107
+ dist_score, dist_score_idxs = torch.min(distances, dim=1)
108
+ s_star = torch.max(dist_score)
109
+ segm_map = dist_score.view(1, 1, 28, 28)
110
+
111
+
112
+ y_score.append(s_star.cpu().numpy())
113
+
114
+
115
+ best_threshold = np.mean(y_score) + 2 * np.std(y_score)
116
+
117
+ plt.hist(y_score,bins=50)
118
+ plt.vlines(x=best_threshold,ymin=0,ymax=30,color='r')
119
+ plt.show()
120
+
121
+
122
+ y_score = []
123
+ y_true=[]
124
+
125
+ for classes in ['color','good','cut','hole','metal_contamination','thread']:
126
+ folder_path = Path(r'/content/carpet/test/{}'.format(classes))
127
+
128
+ for pth in tqdm(folder_path.iterdir(),leave=False):
129
+
130
+ class_label = pth.parts[-2]
131
+ with torch.no_grad():
132
+ test_image = transform(Image.open(pth)).unsqueeze(0)
133
+ features = backbone(test_image)
134
+
135
+ distances = torch.cdist(features, memory_bank, p=2.0)
136
+ dist_score, dist_score_idxs = torch.min(distances, dim=1)
137
+ s_star = torch.max(dist_score)
138
+ segm_map = dist_score.view(1, 1, 28, 28)
139
+
140
+ y_score.append(s_star.cpu().numpy())
141
+ y_true.append(0 if class_label == 'good' else 1)
142
+
143
+
144
+
145
+ # plotting the y_score values which do not belong to 'good' class
146
+
147
+ y_score_nok = [score for score,true in zip(y_score,y_true) if true==1]
148
+ plt.hist(y_score_nok,bins=50)
149
+ plt.vlines(x=best_threshold,ymin=0,ymax=30,color='r')
150
+ plt.show()
151
+
152
+
153
+ test_image = transform(Image.open(r'/content/carpet/test/color/000.png')).unsqueeze(0)
154
+ features = backbone(test_image)
155
+
156
+ distances = torch.cdist(features, memory_bank, p=2.0)
157
+ dist_score, dist_score_idxs = torch.min(distances, dim=1)
158
+ s_star = torch.max(dist_score)
159
+ segm_map = dist_score.view(1, 1, 28, 28)
160
+
161
+ segm_map = torch.nn.functional.interpolate( # Upscale by bi-linaer interpolation to match the original input resolution
162
+ segm_map,
163
+ size=(224, 224),
164
+ mode='bilinear'
165
+ )
166
+
167
+ plt.imshow(segm_map.cpu().squeeze(), cmap='jet')
168
+
169
+
170
+
171
+
172
+ from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score
173
+
174
+
175
+ # Calculate AUC-ROC score
176
+ auc_roc_score = roc_auc_score(y_true, y_score)
177
+ print("AUC-ROC Score:", auc_roc_score)
178
+
179
+ # Plot ROC curve
180
+ fpr, tpr, thresholds = roc_curve(y_true, y_score)
181
+ plt.figure()
182
+ plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
183
+ plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
184
+ plt.xlabel('False Positive Rate')
185
+ plt.ylabel('True Positive Rate')
186
+ plt.title('Receiver Operating Characteristic (ROC) Curve')
187
+ plt.legend(loc="lower right")
188
+ plt.show()
189
+
190
+ f1_scores = [f1_score(y_true, y_score >= threshold) for threshold in thresholds]
191
+
192
+ # Select the best threshold based on F1 score
193
+ best_threshold = thresholds[np.argmax(f1_scores)]
194
+
195
+ print(f'best_threshold = {best_threshold}')
196
+
197
+ # Generate confusion matrix
198
+ cm = confusion_matrix(y_true, (y_score >= best_threshold).astype(int))
199
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['OK','NOK'])
200
+ disp.plot()
201
+ plt.show()
202
+
203
+
204
+
205
+ import cv2, time
206
+ from IPython.display import clear_output
207
+
208
+ backbone.eval()
209
+
210
+ import gradio as gr
211
+ import torch
212
+ import numpy as np
213
+ from PIL import Image
214
+ import matplotlib.pyplot as plt
215
+ import io
216
+
217
+ # -----------------
218
+
219
+ def detect_fault(uploaded_image):
220
+ # Convert uploaded image
221
+ test_image = transform(uploaded_image).unsqueeze(0)
222
+
223
+ with torch.no_grad():
224
+ features = backbone(test_image)
225
+
226
+ distances = torch.cdist(features, memory_bank, p=2.0)
227
+ dist_score, dist_score_idxs = torch.min(distances, dim=1)
228
+ s_star = torch.max(dist_score)
229
+ segm_map = dist_score.view(1, 1, 28, 28)
230
+ segm_map = torch.nn.functional.interpolate(
231
+ segm_map,
232
+ size=(224, 224),
233
+ mode='bilinear'
234
+ ).cpu().squeeze().numpy()
235
+
236
+ y_score_image = s_star.cpu().numpy()
237
+ y_pred_image = 1*(y_score_image >= best_threshold)
238
+ class_label = ['Image If OK','Image is Not OK']
239
+
240
+ # --- Plot results ---
241
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
242
+
243
+ # Original image
244
+ axs[0].imshow(test_image.squeeze().permute(1,2,0).cpu().numpy())
245
+ axs[0].set_title("Original Image")
246
+ axs[0].axis("off")
247
+
248
+ # Heatmap
249
+ axs[1].imshow(segm_map, cmap='jet')
250
+ axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
251
+ axs[1].axis("off")
252
+
253
+ # Segmentation map
254
+ axs[2].imshow((segm_map > best_threshold*1.25), cmap='gray')
255
+ axs[2].set_title("Fault Segmentation Map")
256
+ axs[2].axis("off")
257
+
258
+ # Save plot to image
259
+ buf = io.BytesIO()
260
+ plt.savefig(buf, format="png")
261
+ buf.seek(0)
262
+ result_image = Image.open(buf)
263
+ plt.close(fig)
264
+
265
+ return result_image
266
+
267
+ # Gradio UI
268
+ demo = gr.Interface(
269
+ fn=detect_fault,
270
+ inputs=gr.Image(type="pil", label="Upload Image"),
271
+ outputs=gr.Image(type="pil", label="Detection Result"),
272
+ title="Fault Detection in Images",
273
+ description="Upload an image and the model will detect if there are any faults and show the segmentation map."
274
+ )
275
+
276
+ if __name__ == "__main__":
277
+ demo.launch()
278
+