# Hugging Face Space status: Runtime error
import io
import os
import shutil
import tarfile
import urllib.request
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.transforms import transforms
from tqdm.auto import tqdm
# Transform: preprocessing shared by the memory-bank build, threshold
# calibration and inference -- resize to 224x224, then convert to a float
# tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# Dataset download
def _ensure_dataset(url, archive_path="carpet.tar.xz", extract_dir=".",
                    marker_dir="carpet/train/good"):
    """Download and extract the dataset archive unless it is already on disk.

    Previously the archive was downloaded and extracted on every start-up.
    This skips both steps when ``marker_dir`` (the folder the memory bank is
    built from) already exists under ``extract_dir``. It skips only the
    download when the archive file is already present.

    Args:
        url: Location of the ``.tar.xz`` archive.
        archive_path: Local path the archive is stored at.
        extract_dir: Directory the archive is extracted into.
        marker_dir: Path (relative to ``extract_dir``) whose existence means
            the extraction has already happened.

    Raises:
        RuntimeError: If the archive cannot be downloaded.
    """
    if Path(extract_dir, marker_dir).is_dir():
        return
    archive = Path(archive_path)
    if not archive.is_file():
        partial = archive.with_name(archive.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
        except OSError as err:  # URLError / HTTPError are OSError subclasses
            partial.unlink(missing_ok=True)
            raise RuntimeError(f"Could not download dataset from {url}") from err
        # Rename only after a complete download, so an interrupted transfer
        # never leaves a truncated archive that a later run would trust.
        os.replace(partial, archive)
    with tarfile.open(archive) as tar:
        if hasattr(tarfile, "data_filter"):
            # Reject absolute paths, '..' components, links escaping the
            # destination, device files, etc. (PEP 706 extraction filters).
            tar.extractall(extract_dir, filter="data")
        else:
            tar.extractall(extract_dir)


_ensure_dataset(
    "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937484-1629951672/carpet.tar.xz"
)
# Feature extractor class
class resnet_feature_extractor(torch.nn.Module):
    """Frozen ResNet-50 that returns one feature vector per spatial patch.

    Forward hooks capture the outputs of the last blocks of ``layer2`` and
    ``layer3``. Each captured map goes through a 3x3 average pool (stride 1,
    no padding) and is then adaptively pooled to the spatial size of the raw
    ``layer2`` map. The two maps are concatenated along the channel axis.

    Note: the hooks write into the shared ``self.features`` list, so one
    instance must not run overlapping forward passes.
    """

    def __init__(self):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False
        # Filled by the hooks during a forward pass; reset at the start of forward().
        self.features = []
        # Built once here instead of on every forward call.
        self.avg = torch.nn.AvgPool2d(3, stride=1)
        self.model.layer2[-1].register_forward_hook(self._store_output)
        self.model.layer3[-1].register_forward_hook(self._store_output)

    def _store_output(self, module, inputs, output):
        """Forward hook: record the hooked block's output tensor."""
        self.features.append(output)

    def forward(self, input):
        """Return patch features for a batch of images.

        Args:
            input: Image batch of shape ``(B, 3, H, W)``. ``B`` may be > 1.

        Returns:
            Tensor of shape ``(B * h * w, C)``. ``h x w`` is the spatial size
            of the ``layer2`` output, and ``C`` is the channel count of the
            ``layer2`` and ``layer3`` maps combined. Rows are ordered by image,
            then row-major over ``(h, w)``. For ``B == 1`` this is exactly the
            previous ``reshape(C, -1).T`` layout.
        """
        self.features = []
        with torch.no_grad():
            self.model(input)
        fmap_size = self.features[0].shape[-2]
        resized_maps = [
            torch.nn.functional.adaptive_avg_pool2d(self.avg(fmap), fmap_size)
            for fmap in self.features
        ]
        patch = torch.cat(resized_maps, dim=1)  # (B, C, h, w)
        # Move channels last before flattening so that each row is one patch of
        # one image. The old reshape(C, -1) mixed different images' values into
        # the same row whenever B > 1.
        return patch.permute(0, 2, 3, 1).reshape(-1, patch.shape[1])
# Initialize backbone
backbone = resnet_feature_extractor()


def _load_image_tensor(path):
    """Open an image file and return ``transform(image)`` as a batch of one.

    The file handle is closed before returning. The image is converted to RGB
    so that grayscale, RGBA or palette files still yield three channels.
    """
    with Image.open(path) as img:
        return transform(img.convert("RGB")).unsqueeze(0)


# Memory bank: patch features of every defect-free training image, stacked
# into a single tensor (one row per patch).
folder_path = Path("carpet/train/good")
# Sort for a deterministic patch order. Skip sub-directories and hidden files
# (e.g. .DS_Store) that Image.open cannot decode.
train_paths = sorted(
    p for p in folder_path.iterdir() if p.is_file() and not p.name.startswith(".")
)
if len(train_paths) < 2:
    # Leave-one-out calibration below needs at least one *other* image.
    raise RuntimeError(
        f"Need at least two training images in {folder_path} to calibrate "
        f"the threshold, found {len(train_paths)}."
    )

memory_bank = []
patch_counts = []  # rows contributed by each image, in memory_bank order
with torch.no_grad():
    for pth in tqdm(train_paths, leave=False):
        features = backbone(_load_image_tensor(pth)).cpu().detach()
        memory_bank.append(features)
        patch_counts.append(features.shape[0])
memory_bank = torch.cat(memory_bank, dim=0)

# Threshold: mean + 2 * std of the training images' scores. An image's score
# is its largest patch-to-bank nearest-neighbour distance.
# Each image is scored leave-one-out, i.e. with its own patches masked out of
# the bank. Scoring it against a bank that still contains its own patches (as
# before) makes every nearest-neighbour distance ~0. That collapsed
# best_threshold to ~0 and flagged every uploaded image as faulty.
# The features are reused from the memory bank instead of re-running the
# backbone on the same files.
y_score = []
start = 0
with torch.no_grad():
    for count in tqdm(patch_counts, leave=False):
        stop = start + count
        distances = torch.cdist(memory_bank[start:stop], memory_bank, p=2.0)
        distances[:, start:stop] = float("inf")
        dist_score, _ = torch.min(distances, dim=1)
        s_star = torch.max(dist_score)
        y_score.append(s_star.cpu().numpy())
        start = stop
best_threshold = np.mean(y_score) + 2 * np.std(y_score)
# Gradio Function
def detect_fault(uploaded_image):
    """Classify an uploaded image as OK / not OK and render the evidence.

    Args:
        uploaded_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        PIL image of a 1x3 figure with three panels:
        - the preprocessed input;
        - the per-patch distance heat-map, titled with
          ``score / best_threshold`` and the prediction;
        - that heat-map thresholded at ``1.25 * best_threshold``.

    Raises:
        gr.Error: If no image was uploaded (Gradio shows the message).
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image.")
    # Uploads may be RGBA/palette/grayscale PNGs; the backbone needs 3 channels.
    test_image = transform(uploaded_image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        features = backbone(test_image)
        distances = torch.cdist(features, memory_bank, p=2.0)
        # Per-patch anomaly = distance to the nearest memory-bank patch;
        # the image score is the worst patch.
        dist_score, _ = torch.min(distances, dim=1)
        s_star = torch.max(dist_score)
        # Rebuild the square patch grid from the number of scores instead of
        # hard-coding 28x28 (assumes a square input, which `transform` produces).
        side = round(dist_score.numel() ** 0.5)
        segm_map = dist_score.view(1, 1, side, side)
        segm_map = torch.nn.functional.interpolate(
            segm_map,
            size=tuple(test_image.shape[-2:]),
            mode='bilinear',
        ).cpu().squeeze().numpy()
    y_score_image = s_star.cpu().numpy()
    y_pred_image = int(y_score_image >= best_threshold)
    class_label = ['Image Is OK', 'Image Is Not OK']

    # Plot results. Save through the Figure object rather than pyplot's global
    # "current figure", and always close it so that failing requests do not
    # leak figures.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    try:
        axs[0].imshow(test_image.squeeze().permute(1, 2, 0).cpu().numpy())
        axs[0].set_title("Original Image")
        axs[0].axis("off")
        axs[1].imshow(segm_map, cmap='jet')
        axs[1].set_title(f"Anomaly Score: {y_score_image / best_threshold:0.4f}\nPrediction: {class_label[y_pred_image]}")
        axs[1].axis("off")
        axs[2].imshow((segm_map > best_threshold * 1.25), cmap='gray')
        axs[2].set_title("Fault Segmentation Map")
        axs[2].axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()  # decode now, while the buffer is known to be intact
    return result_image
# Launch Gradio App
# One image in, one image out: the upload is passed to detect_fault and the
# rendered three-panel figure is shown back to the user.
demo = gr.Interface(
    detect_fault,
    gr.Image(type="pil", label="Upload Image"),
    gr.Image(type="pil", label="Detection Result"),
    title="Fault Detection in Images",
    description=(
        "Upload an image and the model will detect if there are any faults "
        "and show the segmentation map."
    ),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()