gmustafa413 committed on
Commit
118673f
·
verified ·
1 Parent(s): f8b91b3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+
8
+ from pathlib import Path
9
+ import os, shutil
10
+ from tqdm.auto import tqdm
11
+ import torchvision
12
+ from torch.utils.data import DataLoader
13
+ from torchvision.datasets import ImageFolder
14
+ from torchvision.transforms import transforms
15
+ import torch.optim as optim
16
+ from torchvision.models import resnet50, ResNet50_Weights
17
+ import urllib.request
18
+ import tarfile
19
+
20
# Transform
# Shared preprocessing for memory-bank building, threshold calibration and
# inference: resize to ResNet-50's 224x224 input and scale pixels to [0, 1].
# NOTE(review): there is no ImageNet mean/std Normalize step — all three
# stages use this same transform so scores stay self-consistent; confirm
# before adding normalization, since it would invalidate the threshold.
transform = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor()
])
25
+
26
# Dataset download
# MVTec AD "carpet" class. Guard both steps so a restart of the app does not
# re-fetch and re-extract the (large) archive every time.
ARCHIVE = "carpet.tar.xz"

if not os.path.exists(ARCHIVE):
    urllib.request.urlretrieve(
        "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937484-1629951672/carpet.tar.xz",
        ARCHIVE
    )

if not os.path.isdir("carpet"):
    with tarfile.open(ARCHIVE) as f:
        # NOTE(review): extractall() on a downloaded archive is a
        # path-traversal risk on Python < 3.12 defaults; the MVTec mirror is
        # treated as trusted here.
        f.extractall('.')
34
+
35
# Feature extractor class
class resnet_feature_extractor(torch.nn.Module):
    """Frozen ResNet-50 patch-feature extractor (PatchCore style).

    Forward hooks on the last blocks of ``layer2`` and ``layer3`` capture
    their activation maps. ``forward`` smooths both maps with a 3x3 average
    pool, resizes them to layer2's spatial size, concatenates them along the
    channel axis and returns one feature vector per spatial patch.
    """

    def __init__(self):
        super(resnet_feature_extractor, self).__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()
        # Inference-only backbone: freeze every weight.
        for param in self.model.parameters():
            param.requires_grad = False

        # Constant 3x3 smoothing pool — build once here instead of
        # re-creating it on every forward() call as the original did.
        self.avg = torch.nn.AvgPool2d(3, stride=1)

        def hook(module, input, output):
            # Accumulate intermediate activations; forward() resets the list.
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        """Return per-patch features of shape (H*W, C2+C3) for one batch.

        The classifier output is discarded; only the hooked feature maps
        matter.
        """
        self.features = []
        with torch.no_grad():
            _ = self.model(input)

        # layer2's map is the larger one; resize layer3's map up to it.
        # Depends on the input resolution, so the pool is created per call.
        fmap_size = self.features[0].shape[-2]
        self.resize = torch.nn.AdaptiveAvgPool2d(fmap_size)

        resized_maps = [self.resize(self.avg(fmap)) for fmap in self.features]
        patch = torch.cat(resized_maps, 1)
        # (1, C, H, W) -> (H*W, C): one row per spatial patch.
        patch = patch.reshape(patch.shape[1], -1).T
        return patch
63
+
64
# Build the frozen backbone once, then fill the patch memory bank with the
# features of every defect-free ("good") training image.
backbone = resnet_feature_extractor()

folder_path = Path("carpet/train/good")
with torch.no_grad():
    per_image_features = [
        backbone(transform(Image.open(pth)).unsqueeze(0)).cpu().detach()
        for pth in tqdm(folder_path.iterdir(), leave=False)
    ]
# One row per patch, across all training images.
memory_bank = torch.cat(per_image_features, dim=0)
76
+
77
# Threshold calibration on the training ("good") images.
# BUG FIX: every training patch is itself a row of memory_bank, so its
# 1-nearest-neighbour distance is always ~0. With torch.min the calibration
# scores collapse to zero and best_threshold becomes ~0, flagging every
# upload as faulty. Use the 2nd-nearest neighbour (kthvalue k=2) instead,
# which skips the self-match.
y_score = []
for pth in tqdm(folder_path.iterdir(), leave=False):
    data = transform(Image.open(pth)).unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
    distances = torch.cdist(features, memory_bank, p=2.0)
    # Distance of each patch to its 2nd-closest memory-bank entry
    # (the closest is the patch itself, at distance ~0).
    dist_score, _ = torch.kthvalue(distances, 2, dim=1)
    # Image-level score: the worst (most distant) patch.
    s_star = torch.max(dist_score)
    y_score.append(s_star.cpu().numpy())

# Flag images scoring beyond mean + 2*std of the good-image scores.
best_threshold = np.mean(y_score) + 2 * np.std(y_score)
89
+
90
+
91
# Gradio Function

def detect_fault(uploaded_image):
    """Score one uploaded image against the good-image memory bank.

    Args:
        uploaded_image: PIL image from the Gradio widget (any mode).

    Returns:
        A PIL image with three panels: the input, the anomaly heat map with
        the score/verdict, and a binary fault-segmentation mask.
    """
    # Gradio can deliver RGBA or greyscale images; ResNet-50 needs 3
    # channels, so normalise the mode before transforming.
    test_image = transform(uploaded_image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        features = backbone(test_image)

    # Per-patch distance to the closest memory-bank entry.
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, _ = torch.min(distances, dim=1)
    # Image-level anomaly score: the worst patch.
    s_star = torch.max(dist_score)

    # Derive the (square) feature-map side from the number of patches
    # instead of hard-coding 28x28, so a different input resolution in the
    # shared transform keeps working.
    side = int(dist_score.numel() ** 0.5)
    segm_map = dist_score.view(1, 1, side, side)
    segm_map = torch.nn.functional.interpolate(
        segm_map,
        size=(224, 224),
        mode='bilinear'
    ).cpu().squeeze().numpy()

    y_score_image = s_star.cpu().numpy()
    y_pred_image = 1*(y_score_image >= best_threshold)
    class_label = ['Image Is OK','Image Is Not OK']

    # Plot results
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(test_image.squeeze().permute(1,2,0).cpu().numpy())
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    axs[1].imshow(segm_map, cmap='jet')
    axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
    axs[1].axis("off")

    axs[2].imshow((segm_map > best_threshold*1.25), cmap='gray')
    axs[2].set_title("Fault Segmentation Map")
    axs[2].axis("off")

    buf = io.BytesIO()
    # Save THIS figure explicitly: plt.savefig() targets the process-global
    # "current" figure, which is fragile when Gradio serves requests on
    # worker threads.
    fig.savefig(buf, format="png")
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    return result_image
134
+
135
# Launch Gradio App
# Single image in -> single image out around detect_fault; the output is the
# 3-panel matplotlib figure rendered to a PIL image.
demo = gr.Interface(
    fn=detect_fault,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description="Upload an image and the model will detect if there are any faults and show the segmentation map."
)

# Only start the server when run as a script (Spaces also imports app.py).
if __name__ == "__main__":
    demo.launch()