gmustafa413 commited on
Commit
42790a4
·
verified ·
1 Parent(s): 4468005

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import os, shutil
4
+ import matplotlib.pyplot as plt
5
+
6
+ from PIL import Image
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ import torch
11
+ import torchvision
12
+ from torch.utils.data import DataLoader
13
+ from torchvision.datasets import ImageFolder
14
+ from torchvision.transforms import transforms
15
+ import torch.optim as optim
16
+
17
+ from torchvision.models import resnet50 , ResNet50_Weights
18
+
19
# Preprocessing: resize to the 224x224 input ResNet-50 was trained on,
# then convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

import urllib.request
import tarfile

# Download and extract the MVTec-AD "transistor" dataset exactly once:
# an app restart should not re-download / re-extract the whole archive.
_ARCHIVE = "transistor.tar.xz"
if not os.path.isdir("transistor"):
    if not os.path.exists(_ARCHIVE):
        urllib.request.urlretrieve(
            "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938166-1629953277/transistor.tar.xz",
            _ARCHIVE,
        )
    # NOTE(security): extractall on a downloaded tar can write outside '.'
    # via crafted member paths; on Python >= 3.12 pass filter='data'.
    with tarfile.open(_ARCHIVE) as f:
        f.extractall('.')
33
+
34
+
35
class resnet_feature_extractor(torch.nn.Module):
    """Extract PatchCore-style patch features from a pretrained ResNet-50.

    Forward hooks capture the outputs of layer2 and layer3; the two maps
    are locally smoothed, resized to a common spatial size, concatenated
    channel-wise, and flattened into a (num_patches, channels) matrix.
    """

    def __init__(self):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Frozen, inference-only backbone.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # 3x3 average pooling aggregates each patch with its local
        # neighbourhood. Built once here instead of on every forward call.
        self.avg = torch.nn.AvgPool2d(3, stride=1)

        # Hook to extract feature maps.
        def hook(module, input, output) -> None:
            """Save the extracted feature map on self.features."""
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        """Return the patch-feature matrix for *input*.

        NOTE(review): the final reshape assumes a batch of size 1 — a larger
        batch would silently mix patches across images; confirm callers.
        """
        self.features = []
        with torch.no_grad():
            _ = self.model(input)

        # Resize the deeper (smaller) map up to the first map's h/w so the
        # two can be concatenated channel-wise.
        fmap_size = self.features[0].shape[-2]
        resize = torch.nn.AdaptiveAvgPool2d(fmap_size)

        resized_maps = [resize(self.avg(fmap)) for fmap in self.features]
        patch = torch.cat(resized_maps, 1)           # merge the resized feature maps
        patch = patch.reshape(patch.shape[1], -1).T  # create a column tensor (patches x channels)

        return patch
70
+
71
# Sanity-check the feature extractor on one known-good test image.
# The archive above is extracted into '.', so a relative path is used
# (the original hard-coded '/content/...' only exists inside Google Colab).
image = Image.open('transistor/test/good/000.png')
image = transform(image).unsqueeze(0)

backbone = resnet_feature_extractor()
feature = backbone(image)

print(feature.shape)
84
# Build the memory bank: patch features of every defect-free training image.
memory_bank = []

folder_path = Path('transistor/train/good')  # relative: archive extracted into '.'

for pth in tqdm(folder_path.iterdir(), leave=False):
    with torch.no_grad():
        data = transform(Image.open(pth)).unsqueeze(0)
        features = backbone(data)
        memory_bank.append(features.cpu().detach())

# One big (total_patches, channels) tensor for nearest-neighbour search.
memory_bank = torch.cat(memory_bank, dim=0)
96
# Score every (nominal) training image against the memory bank to
# calibrate a decision threshold as mean + 2*std of the good scores.
y_score = []

folder_path = Path('transistor/train/good')  # relative: archive extracted into '.'

for pth in tqdm(folder_path.iterdir(), leave=False):
    data = transform(Image.open(pth)).unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, _ = torch.min(distances, dim=1)  # nearest-neighbour distance per patch
    s_star = torch.max(dist_score)               # image-level anomaly score
    y_score.append(s_star.cpu().numpy())

best_threshold = np.mean(y_score) + 2 * np.std(y_score)

plt.hist(y_score, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
118
+
119
+
120
# Score the held-out test set: 'good' images plus every defect class.
y_score = []
y_true = []

for classes in ['bent_lead', 'good', 'cut_lead', 'damaged_case', 'misplaced']:
    folder_path = Path(f'transistor/test/{classes}')  # relative: archive extracted into '.'

    for pth in tqdm(folder_path.iterdir(), leave=False):
        class_label = pth.parts[-2]  # parent directory name is the class
        with torch.no_grad():
            test_image = transform(Image.open(pth)).unsqueeze(0)
            features = backbone(test_image)

        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score, _ = torch.min(distances, dim=1)  # nearest-neighbour distance per patch
        s_star = torch.max(dist_score)               # image-level anomaly score

        y_score.append(s_star.cpu().numpy())
        y_true.append(0 if class_label == 'good' else 1)

# Plot the y_score values of the defective (label 1) images against the threshold.
y_score_nok = [score for score, true in zip(y_score, y_true) if true == 1]
plt.hist(y_score_nok, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
147
+
148
+
149
# Visualise the per-patch anomaly heatmap for one good test image.
test_image = transform(Image.open('transistor/test/good/000.png')).unsqueeze(0)
features = backbone(test_image)

distances = torch.cdist(features, memory_bank, p=2.0)
dist_score, _ = torch.min(distances, dim=1)
s_star = torch.max(dist_score)

# The patch scores form a square grid; derive its side length instead of
# hardcoding 28 so a change of input size or backbone layer doesn't break this.
side = int(dist_score.numel() ** 0.5)
segm_map = dist_score.view(1, 1, side, side)

# Upscale by bilinear interpolation to match the original input resolution.
segm_map = torch.nn.functional.interpolate(
    segm_map,
    size=(224, 224),
    mode='bilinear'
)

plt.imshow(segm_map.cpu().squeeze(), cmap='jet')
plt.show()  # consistent with the other plots in this script
164
+
165
+
166
+
167
+
168
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score

# Convert once: y_score is a list of 0-d numpy scalars and y_true a list of
# ints; every comparison below otherwise relies on implicit list->array coercion.
y_score = np.asarray(y_score)
y_true = np.asarray(y_true)

# Calculate AUC-ROC score.
auc_roc_score = roc_auc_score(y_true, y_score)
print("AUC-ROC Score:", auc_roc_score)

# Plot ROC curve.
fpr, tpr, thresholds = roc_curve(y_true, y_score)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

# Re-pick the operating threshold by maximising F1 over the ROC thresholds.
f1_scores = [f1_score(y_true, y_score >= threshold) for threshold in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]

print(f'best_threshold = {best_threshold}')

# Generate the confusion matrix at the chosen threshold.
cm = confusion_matrix(y_true, (y_score >= best_threshold).astype(int))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['OK', 'NOK'])
disp.plot()
plt.show()
199
+
200
+
201
+
202
# (Removed `import cv2, time` and `from IPython.display import clear_output`:
# nothing in this file uses OpenCV, `time`, or notebook display helpers, and
# cv2/IPython are not guaranteed to be installed outside a Colab notebook.)

# Make sure the extractor is in inference mode for the Gradio handler.
backbone.eval()
206
+
207
+ import gradio as gr
208
+ import torch
209
+ import numpy as np
210
+ from PIL import Image
211
+ import matplotlib.pyplot as plt
212
+ import io
213
+
214
+ # -----------------
215
+
216
def detect_fault(uploaded_image):
    """Run anomaly detection on one uploaded image.

    Args:
        uploaded_image: PIL image from the Gradio upload widget.

    Returns:
        A PIL image with three panels: the input, the anomaly heatmap with
        the score/prediction, and a binary fault-segmentation map.
    """
    # Gradio may hand over RGBA or greyscale images; the ResNet backbone
    # requires exactly 3 channels, so normalise to RGB first.
    test_image = transform(uploaded_image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        features = backbone(test_image)

    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, _ = torch.min(distances, dim=1)  # nearest-neighbour distance per patch
    s_star = torch.max(dist_score)               # image-level anomaly score

    # Patch scores form a square grid; derive the side instead of hardcoding 28.
    side = int(dist_score.numel() ** 0.5)
    segm_map = torch.nn.functional.interpolate(
        dist_score.view(1, 1, side, side),
        size=(224, 224),
        mode='bilinear'
    ).cpu().squeeze().numpy()

    y_score_image = s_star.cpu().numpy()
    y_pred_image = 1 * (y_score_image >= best_threshold)
    class_label = ['Image is OK', 'Image is Not OK']  # fixed typo: was 'Image If OK'

    # --- Plot results ---
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # Original image.
    axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    # Heatmap with score and prediction.
    axs[1].imshow(segm_map, cmap='jet')
    axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
    axs[1].axis("off")

    # Binary segmentation map.
    # NOTE(review): compares per-pixel distances against the *image-level*
    # threshold scaled by 1.25 — kept as-is, but verify the intended scale.
    axs[2].imshow((segm_map > best_threshold * 1.25), cmap='gray')
    axs[2].set_title("Fault Segmentation Map")
    axs[2].axis("off")

    # Render this specific figure (not the implicit current one) into a
    # PIL image for Gradio.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    return result_image
263
+
264
# Gradio UI: a single-function interface that maps an uploaded PIL image to
# the rendered result figure produced by detect_fault.
demo = gr.Interface(
    fn=detect_fault,
    inputs=gr.Image(type="pil", label="Upload Image"),    # type="pil" hands a PIL.Image to fn
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map."
)

# Standard script entry point: only launch the server when run directly.
if __name__ == "__main__":
    demo.launch()