gmustafa413 committed on
Commit
388e8e1
·
verified ·
1 Parent(s): bf646d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -279
app.py CHANGED
@@ -1,279 +0,0 @@
1
- from pathlib import Path
2
- import numpy as np
3
- import os, shutil
4
- import matplotlib.pyplot as plt
5
-
6
- from PIL import Image
7
-
8
- from tqdm.auto import tqdm
9
-
10
- import torch
11
- import torchvision
12
- from torch.utils.data import DataLoader
13
- from torchvision.datasets import ImageFolder
14
- from torchvision.transforms import transforms
15
- import torch.optim as optim
16
-
17
- from torchvision.models import resnet50, ResNet50_Weights
18
-
19
-
20
# Preprocessing pipeline: resize every image to the 224x224 input
# resolution expected by the ImageNet-pretrained ResNet-50 backbone and
# convert it to a float tensor in [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
24
-
25
import urllib.request
import tarfile

# MVTec AD "carpet" category archive.
_DATASET_URL = (
    "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282"
    "/download/420937484-1629951672/carpet.tar.xz"
)
_ARCHIVE = "carpet.tar.xz"

# Download once; skip the (large) fetch if a previous run already got it.
if not os.path.exists(_ARCHIVE):
    urllib.request.urlretrieve(_DATASET_URL, _ARCHIVE)

# Extract into the working directory. The 'data' filter (Python >= 3.12,
# backported to 3.11.4 / 3.10.12) rejects path-traversal members in the tar.
with tarfile.open(_ARCHIVE) as f:
    try:
        f.extractall('.', filter="data")
    except TypeError:  # older Python without the `filter` argument
        f.extractall('.')
33
-
34
-
35
class resnet_feature_extractor(torch.nn.Module):
    """Frozen ResNet-50 patch-feature extractor (PatchCore-style).

    Feature maps are captured from the last blocks of ``layer2`` and
    ``layer3`` via forward hooks, locally averaged (3x3), resized to a
    common spatial size and concatenated into a (num_patches, channels)
    matrix, one row per spatial location.
    """

    def __init__(self):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)

        # The backbone is used purely for inference: freeze everything.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Hook to extract feature maps.
        def hook(module, input, output) -> None:
            """This hook saves the extracted feature map on self.features."""
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

        # 3x3 local smoothing of each feature map. Built once here rather
        # than on every forward() call as the original did.
        self.avg = torch.nn.AvgPool2d(3, stride=1)

    def forward(self, input):
        """Return the (H*W, C2+C3) patch-feature matrix for a batch."""
        self.features = []
        with torch.no_grad():
            _ = self.model(input)  # fills self.features via the hooks

        # Resize every hooked map to the spatial size of the first
        # (largest) one; the size depends on the input, so this pool is
        # (re)created per call.
        fmap_size = self.features[0].shape[-2]  # feature map h (== w)
        self.resize = torch.nn.AdaptiveAvgPool2d(fmap_size)

        resized_maps = [self.resize(self.avg(fmap)) for fmap in self.features]
        patch = torch.cat(resized_maps, 1)  # merge the resized feature maps
        patch = patch.reshape(patch.shape[1], -1).T  # create a column tensor

        return patch
70
-
71
-
72
-
73
# Smoke test: push one defective test image through the backbone and
# inspect the shape of the resulting patch-feature matrix.
image = Image.open(r'/content/carpet/test/color/000.png')
image = transform(image).unsqueeze(0)

backbone = resnet_feature_extractor()
feature = backbone(image)

# print(backbone.features[0].shape)
# print(backbone.features[1].shape)

print(feature.shape)

# plt.imshow(image[0].permute(1,2,0))
85
-
86
# Build the memory bank: patch features of every defect-free training image.
memory_bank = []

folder_path = Path(r'/content/carpet/train/good')

# sorted() makes the memory-bank row order deterministic across runs
# (Path.iterdir yields entries in arbitrary filesystem order).
for pth in tqdm(sorted(folder_path.iterdir()), leave=False):
    with torch.no_grad():
        data = transform(Image.open(pth)).unsqueeze(0)
        features = backbone(data)
        # detach() before cpu(): drop the graph reference first.
        memory_bank.append(features.detach().cpu())

memory_bank = torch.cat(memory_bank, dim=0)
97
-
98
# Score every defect-free training image so a decision threshold can be
# derived from the distribution of "good" anomaly scores.
y_score = []

folder_path = Path(r'/content/carpet/train/good')

for pth in tqdm(sorted(folder_path.iterdir()), leave=False):
    data = transform(Image.open(pth)).unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
        # Distance of each patch to its nearest memory-bank patch.
        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score = torch.min(distances, dim=1).values
        # Image-level anomaly score: the worst (largest) patch distance.
        # (The original also built an unused 28x28 segmentation map here;
        # removed as dead code.)
        s_star = torch.max(dist_score)

    y_score.append(s_star.cpu().numpy())

# Mean + 2 sigma of the "good" scores: images above this look anomalous.
best_threshold = np.mean(y_score) + 2 * np.std(y_score)

plt.hist(y_score, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
120
-
121
-
122
# Evaluate on the full test set: score each image and record its true
# label (0 = good, 1 = any defect class).
y_score = []
y_true = []

for classes in ['color', 'good', 'cut', 'hole', 'metal_contamination', 'thread']:
    folder_path = Path(r'/content/carpet/test/{}'.format(classes))

    for pth in tqdm(sorted(folder_path.iterdir()), leave=False):
        class_label = pth.parts[-2]
        with torch.no_grad():
            test_image = transform(Image.open(pth)).unsqueeze(0)
            features = backbone(test_image)

            distances = torch.cdist(features, memory_bank, p=2.0)
            dist_score = torch.min(distances, dim=1).values
            # Image-level anomaly score (unused per-image segmentation map
            # from the original removed as dead code).
            s_star = torch.max(dist_score)

        y_score.append(s_star.cpu().numpy())
        y_true.append(0 if class_label == 'good' else 1)

# Plot the score distribution of the defective (not-OK) images against
# the threshold learned from the good training images.
y_score_nok = [score for score, true in zip(y_score, y_true) if true == 1]
plt.hist(y_score_nok, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()
151
-
152
-
153
# Visualise the anomaly heat map for one known-defective test image.
test_image = transform(Image.open(r'/content/carpet/test/color/000.png')).unsqueeze(0)
features = backbone(test_image)

distances = torch.cdist(features, memory_bank, p=2.0)
dist_score, dist_score_idxs = torch.min(distances, dim=1)
s_star = torch.max(dist_score)
segm_map = dist_score.view(1, 1, 28, 28)

# Upscale by bilinear interpolation to match the original input resolution.
segm_map = torch.nn.functional.interpolate(segm_map, size=(224, 224), mode='bilinear')

plt.imshow(segm_map.cpu().squeeze(), cmap='jet')
168
-
169
-
170
-
171
-
172
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score

# BUG FIX: y_score/y_true are Python lists; `list >= float` raises
# TypeError in the F1 sweep below. Convert once to ndarrays.
y_true = np.asarray(y_true)
y_score = np.asarray(y_score)

# Calculate AUC-ROC score.
auc_roc_score = roc_auc_score(y_true, y_score)
print("AUC-ROC Score:", auc_roc_score)

# Plot ROC curve.
fpr, tpr, thresholds = roc_curve(y_true, y_score)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

# F1 at every candidate threshold produced by roc_curve.
f1_scores = [f1_score(y_true, y_score >= threshold) for threshold in thresholds]

# Select the best threshold based on F1 score.
best_threshold = thresholds[np.argmax(f1_scores)]

print(f'best_threshold = {best_threshold}')

# Generate confusion matrix at the chosen operating point.
cm = confusion_matrix(y_true, (y_score >= best_threshold).astype(int))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['OK', 'NOK'])
disp.plot()
plt.show()
202
-
203
-
204
-
205
# Notebook-era helpers retained from the original Colab workflow.
import cv2, time
from IPython.display import clear_output

# Inference only from here on: keep the backbone in eval mode for the app.
backbone.eval()

# Gradio app dependencies (some are re-imports of modules already loaded
# above; harmless, kept as in the original script).
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io

# -----------------
219
def detect_fault(uploaded_image):
    """Run PatchCore-style anomaly detection on an uploaded PIL image.

    Returns a PIL image with three panels: the input, the anomaly heat
    map (annotated with a normalised score and an OK/NOK verdict) and a
    binary fault-segmentation mask.
    """
    # Force 3 channels: uploads may be grayscale or RGBA, which would
    # not match the RGB input the ResNet backbone expects.
    uploaded_image = uploaded_image.convert("RGB")
    test_image = transform(uploaded_image).unsqueeze(0)

    with torch.no_grad():
        features = backbone(test_image)

    # Nearest-neighbour distance of every patch to the memory bank.
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, dist_score_idxs = torch.min(distances, dim=1)
    s_star = torch.max(dist_score)  # image-level anomaly score
    segm_map = dist_score.view(1, 1, 28, 28)
    segm_map = torch.nn.functional.interpolate(
        segm_map,
        size=(224, 224),
        mode='bilinear'
    ).cpu().squeeze().numpy()

    y_score_image = s_star.cpu().numpy()
    y_pred_image = int(y_score_image >= best_threshold)
    # BUG FIX: label typo in the original ("Image If OK").
    class_label = ['Image is OK', 'Image is Not OK']

    # --- Plot results ---
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    # Heatmap
    axs[1].imshow(segm_map, cmap='jet')
    axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
    axs[1].axis("off")

    # Segmentation map (empirical 1.25x margin over the image threshold)
    axs[2].imshow((segm_map > best_threshold * 1.25), cmap='gray')
    axs[2].set_title("Fault Segmentation Map")
    axs[2].axis("off")

    # Save plot to an in-memory PNG; close the figure to avoid leaking
    # matplotlib state across Gradio requests.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    return result_image
266
-
267
# Gradio UI: wire the detector into a simple image-in / image-out app.
demo = gr.Interface(
    fn=detect_fault,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map.",
)

if __name__ == "__main__":
    demo.launch()
278
-
279
-