Spaces:
Build error
Build error
| """MVTec AD Dataset (CC BY-NC-SA 4.0). | |
| Description: | |
| This script contains PyTorch Dataset, Dataloader and PyTorch | |
| Lightning DataModule for the MVTec AD dataset. | |
| If the dataset is not on the file system, the script downloads and | |
| extracts the dataset and creates PyTorch data objects. | |
| License: | |
| MVTec AD dataset is released under the Creative Commons | |
| Attribution-NonCommercial-ShareAlike 4.0 International License | |
| (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). | |
| Reference: | |
| - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger: | |
| The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for | |
| Unsupervised Anomaly Detection; in: International Journal of Computer Vision | |
| 129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4. | |
| - Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — | |
| A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; | |
| in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), | |
| 9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982. | |
| """ | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import logging | |
| import tarfile | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple, Union | |
| from urllib.request import urlretrieve | |
| import albumentations as A | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| from pandas.core.frame import DataFrame | |
| from pytorch_lightning.core.datamodule import LightningDataModule | |
| from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS | |
| from torch import Tensor | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.dataset import Dataset | |
| from torchvision.datasets.folder import VisionDataset | |
| from anomalib.data.inference import InferenceDataset | |
| from anomalib.data.utils import DownloadProgressBar, read_image | |
| from anomalib.data.utils.split import ( | |
| create_validation_set_from_test_set, | |
| split_normal_images_in_train_set, | |
| ) | |
| from anomalib.pre_processing import PreProcessor | |
logger = logging.getLogger(__name__)


def make_mvtec_dataset(
    path: Path,
    split: Optional[str] = None,
    split_ratio: float = 0.1,
    seed: int = 0,
    create_validation_set: bool = False,
) -> DataFrame:
    """Create MVTec AD samples by parsing the MVTec AD data file structure.

    The files are expected to follow the structure:
        path/to/dataset/split/category/image_filename.png
        path/to/dataset/ground_truth/category/mask_filename.png

    This function creates a dataframe to store the parsed information based on the following format:
    |---|---------------|-------|---------|---------------|---------------------------------------|-------------|
    |   | path          | split | label   | image_path    | mask_path                             | label_index |
    |---|---------------|-------|---------|---------------|---------------------------------------|-------------|
    | 0 | datasets/name |  test |  defect |  filename.png | ground_truth/defect/filename_mask.png | 1           |
    |---|---------------|-------|---------|---------------|---------------------------------------|-------------|

    Args:
        path (Path): Path to dataset
        split (str, optional): Dataset split, one of ``train``, ``val`` or ``test``. Any other value
            (including ``None``) returns the samples of every split. Defaults to None.
        split_ratio (float, optional): Ratio to split normal training images and add to the
            test set in case test set doesn't contain any normal images.
            Defaults to 0.1.
        seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0.
        create_validation_set (bool, optional): Boolean to create a validation set from the test set.
            MVTec AD dataset does not contain a validation set. Those wanting to create a validation set
            could set this flag to ``True``.

    Example:
        The following example shows how to get training samples from MVTec AD bottle category:

        >>> root = Path('./MVTec')
        >>> category = 'bottle'
        >>> path = root / category
        >>> path
        PosixPath('MVTec/bottle')

        >>> samples = make_mvtec_dataset(path, split='train', split_ratio=0.1, seed=0)
        >>> samples.head()
           path         split label image_path                           mask_path                   label_index
        0  MVTec/bottle train good MVTec/bottle/train/good/105.png MVTec/bottle/ground_truth/good/105_mask.png 0
        1  MVTec/bottle train good MVTec/bottle/train/good/017.png MVTec/bottle/ground_truth/good/017_mask.png 0
        2  MVTec/bottle train good MVTec/bottle/train/good/137.png MVTec/bottle/ground_truth/good/137_mask.png 0
        3  MVTec/bottle train good MVTec/bottle/train/good/152.png MVTec/bottle/ground_truth/good/152_mask.png 0
        4  MVTec/bottle train good MVTec/bottle/train/good/109.png MVTec/bottle/ground_truth/good/109_mask.png 0

    Raises:
        RuntimeError: If no ``.png`` file is found anywhere below ``path``.

    Returns:
        DataFrame: an output dataframe containing samples for the requested split (i.e., train, val or test)
    """
    # Every png contributes (dataset path, split dir, label dir, file name),
    # taken from the last three components of its path.
    samples_list = [(str(path),) + filename.parts[-3:] for filename in path.glob("**/*.png")]
    if not samples_list:
        raise RuntimeError(f"Found 0 images in {path}")

    samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"])
    # Mask files are referenced through ``mask_path``; they are not samples themselves.
    # Copy explicitly so the column assignments below write to an independent frame
    # instead of a slice of the unfiltered one (avoids pandas' SettingWithCopyWarning).
    samples = samples[samples.split != "ground_truth"].copy()

    # Create mask_path column. Only the trailing ".png" is removed from the file name;
    # slicing is used because ``str.rstrip`` strips a set of characters, not a suffix.
    image_stem = samples.image_path.str[: -len(".png")]
    samples["mask_path"] = samples.path + "/ground_truth/" + samples.label + "/" + image_stem + "_mask.png"

    # Modify image_path column by prefixing the dataset path, split and label directories.
    samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path

    # Split the normal images in training set if test set doesn't
    # contain any normal images. This is needed because AUC score
    # cannot be computed based on 1-class
    has_normal_test_images = ((samples.split == "test") & (samples.label == "good")).any()
    if not has_normal_test_images:
        samples = split_normal_images_in_train_set(samples, split_ratio, seed)

    # Good images don't have mask. The selection is recomputed here because the
    # frame may have been replaced by the split above.
    samples.loc[(samples.split == "test") & (samples.label == "good"), "mask_path"] = ""

    # Create label index for normal (0) and anomalous (1) images.
    samples["label_index"] = (samples.label != "good").astype(int)

    if create_validation_set:
        samples = create_validation_set_from_test_set(samples, seed=seed)

    # Get the data frame for the split. Unrecognised values keep every sample.
    if split in ("train", "val", "test"):
        samples = samples[samples.split == split]
        samples = samples.reset_index(drop=True)

    return samples
class MVTec(VisionDataset):
    """MVTec AD PyTorch Dataset."""

    def __init__(
        self,
        root: Union[Path, str],
        category: str,
        pre_process: PreProcessor,
        split: str,
        task: str = "segmentation",
        seed: int = 0,
        create_validation_set: bool = False,
    ) -> None:
        """Mvtec AD Dataset class.

        Args:
            root: Path to the MVTec AD dataset
            category: Name of the MVTec AD category.
            pre_process: List of pre_processing object containing albumentation compose.
            split: 'train', 'val' or 'test'
            task: ``classification`` or ``segmentation``
            seed: seed used for the random subset splitting
            create_validation_set: Create a validation subset in addition to the train and test subsets

        Examples:
            >>> from anomalib.data.mvtec import MVTec
            >>> from anomalib.pre_processing import PreProcessor
            >>> pre_process = PreProcessor(image_size=256)
            >>> dataset = MVTec(
            ...     root='./datasets/MVTec',
            ...     category='leather',
            ...     pre_process=pre_process,
            ...     task="classification",
            ...     split="train",
            ... )
            >>> dataset[0].keys()
            dict_keys(['image'])

            >>> dataset.split = "test"
            >>> dataset[0].keys()
            dict_keys(['image', 'image_path', 'label'])

            >>> dataset.task = "segmentation"
            >>> dataset.split = "train"
            >>> dataset[0].keys()
            dict_keys(['image'])

            >>> dataset.split = "test"
            >>> dataset[0].keys()
            dict_keys(['image', 'image_path', 'label', 'mask_path', 'mask'])

            >>> dataset[0]["image"].shape, dataset[0]["mask"].shape
            (torch.Size([3, 256, 256]), torch.Size([256, 256]))
        """
        super().__init__(root)
        self.root = Path(root) if isinstance(root, str) else root
        self.category: str = category
        self.split = split
        self.task = task
        self.pre_process = pre_process

        # The sample table is built once, for the split given at construction time.
        self.samples = make_mvtec_dataset(
            path=self.root / category,
            split=self.split,
            seed=seed,
            create_validation_set=create_validation_set,
        )

    def __len__(self) -> int:
        """Get length of the dataset."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
        """Get dataset item for the index ``index``.

        Args:
            index (int): Index to get the item.

        Raises:
            FileNotFoundError: If the mask of an anomalous sample cannot be read
                (segmentation task, ``val``/``test`` split only).

        Returns:
            Dict[str, Union[str, Tensor]]: For any split other than ``val``/``test`` only the
                pre-processed ``image``. For ``val``/``test`` additionally ``image_path`` and
                ``label``, and for the segmentation task also ``mask_path`` and ``mask``.
        """
        image_path = self.samples.image_path[index]
        image = read_image(image_path)

        # Training samples carry nothing but the pre-processed image.
        if self.split not in ("val", "test"):
            return {"image": self.pre_process(image=image)["image"]}

        label_index = self.samples.label_index[index]

        if self.task != "segmentation":
            pre_processed = self.pre_process(image=image)
            return {
                "image": pre_processed["image"],
                "image_path": image_path,
                "label": label_index,
            }

        mask_path = self.samples.mask_path[index]

        # Only Anomalous (1) images has masks in MVTec AD dataset.
        # Therefore, create empty mask for Normal (0) images.
        if label_index == 0:
            mask = np.zeros(shape=image.shape[:2])
        else:
            raw_mask = cv2.imread(mask_path, flags=0)
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on the division below.
            if raw_mask is None:
                raise FileNotFoundError(f"Mask image is missing or unreadable: {mask_path}")
            mask = raw_mask / 255.0

        # Image and mask go through the pipeline in a single call. The image is
        # pre-processed exactly once per item; a separate image-only pass would
        # be discarded.
        pre_processed = self.pre_process(image=image, mask=mask)

        # Key order: image, image_path, label, mask_path, mask.
        return {
            "image": pre_processed["image"],
            "image_path": image_path,
            "label": label_index,
            "mask_path": mask_path,
            "mask": pre_processed["mask"],
        }
class MVTecDataModule(LightningDataModule):
    """MVTec AD Lightning Data Module."""

    def __init__(
        self,
        root: Union[str, Path],
        category: str,
        # TODO: Remove default values. IAAALD-211
        image_size: Optional[Union[int, Tuple[int, int]]] = None,
        train_batch_size: int = 32,
        test_batch_size: int = 32,
        num_workers: int = 8,
        task: str = "segmentation",
        transform_config_train: Optional[Union[str, A.Compose]] = None,
        transform_config_val: Optional[Union[str, A.Compose]] = None,
        seed: int = 0,
        create_validation_set: bool = False,
    ) -> None:
        """Mvtec AD Lightning Data Module.

        Args:
            root: Path to the MVTec AD dataset
            category: Name of the MVTec AD category.
            image_size: Variable to which image is resized.
            train_batch_size: Training batch size.
            test_batch_size: Testing batch size.
            num_workers: Number of workers.
            task: ``classification`` or ``segmentation``
            transform_config_train: Config for pre-processing during training.
            transform_config_val: Config for pre-processing during validation. When omitted while
                ``transform_config_train`` is given, the training config is reused.
            seed: seed used for the random subset splitting
            create_validation_set: Create a validation subset in addition to the train and test subsets

        Examples:
            >>> from anomalib.data import MVTecDataModule
            >>> datamodule = MVTecDataModule(
            ...     root="./datasets/MVTec",
            ...     category="leather",
            ...     image_size=256,
            ...     train_batch_size=32,
            ...     test_batch_size=32,
            ...     num_workers=8,
            ...     transform_config_train=None,
            ...     transform_config_val=None,
            ... )
            >>> datamodule.setup()

            >>> i, data = next(enumerate(datamodule.train_dataloader()))
            >>> data.keys()
            dict_keys(['image'])
            >>> data["image"].shape
            torch.Size([32, 3, 256, 256])

            >>> i, data = next(enumerate(datamodule.val_dataloader()))
            >>> data.keys()
            dict_keys(['image', 'image_path', 'label', 'mask_path', 'mask'])
            >>> data["image"].shape, data["mask"].shape
            (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256]))
        """
        super().__init__()

        self.root = root if isinstance(root, Path) else Path(root)
        self.category = category
        self.dataset_path = self.root / self.category
        self.transform_config_train = transform_config_train
        self.transform_config_val = transform_config_val
        self.image_size = image_size

        # Validation falls back to the training transforms when only those are given.
        if self.transform_config_train is not None and self.transform_config_val is None:
            self.transform_config_val = self.transform_config_train

        self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size)
        self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size)

        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers

        self.create_validation_set = create_validation_set
        self.task = task
        self.seed = seed

        # Declarations only: the datasets themselves are created in ``setup``.
        self.train_data: Dataset
        self.test_data: Dataset
        if create_validation_set:
            self.val_data: Dataset
        self.inference_data: Dataset

    def prepare_data(self) -> None:
        """Download the dataset if not available.

        Nothing is done when ``root / category`` already exists as a directory. Otherwise the
        archive is downloaded into ``root``, extracted there and then deleted.
        """
        if (self.root / self.category).is_dir():
            logger.info("Found the dataset.")
            return

        self.root.mkdir(parents=True, exist_ok=True)

        logger.info("Downloading the Mvtec AD dataset.")
        url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094"
        dataset_name = "mvtec_anomaly_detection.tar.xz"
        archive_path = self.root / dataset_name
        with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc="MVTec AD") as progress_bar:
            urlretrieve(
                url=f"{url}/{dataset_name}",
                filename=archive_path,
                reporthook=progress_bar.update_to,
            )

        logger.info("Extracting the dataset.")
        with tarfile.open(archive_path) as tar_file:
            # The archive comes from the network. Use the stdlib "data" extraction
            # filter where the interpreter provides it, so members cannot be written
            # outside ``root``.
            if hasattr(tarfile, "data_filter"):
                tar_file.extractall(self.root, filter="data")
            else:
                # NOTE(review): interpreters without extraction filters extract the
                # archive unchecked, exactly as before.
                tar_file.extractall(self.root)

        logger.info("Cleaning the tar file")
        archive_path.unlink()

    def setup(self, stage: Optional[str] = None) -> None:
        """Setup train, validation and test data.

        Args:
            stage: Optional[str]: Train/Val/Test stages. (Default value = None)
        """
        logger.info("Setting up train, validation, test and prediction datasets.")

        # Settings shared by every split; only the transforms and the split name differ.
        shared = {
            "root": self.root,
            "category": self.category,
            "task": self.task,
            "seed": self.seed,
            "create_validation_set": self.create_validation_set,
        }

        if stage in (None, "fit"):
            self.train_data = MVTec(pre_process=self.pre_process_train, split="train", **shared)

        if self.create_validation_set:
            self.val_data = MVTec(pre_process=self.pre_process_val, split="val", **shared)

        # Built for every stage: ``val_dataloader`` serves the test set when no
        # validation set was requested.
        self.test_data = MVTec(pre_process=self.pre_process_val, split="test", **shared)

        if stage == "predict":
            self.inference_data = InferenceDataset(
                path=self.root, image_size=self.image_size, transform_config=self.transform_config_val
            )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Get train dataloader (shuffled, ``train_batch_size``)."""
        return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Get validation dataloader; serves the test set when no validation set was created."""
        dataset = self.val_data if self.create_validation_set else self.test_data
        return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Get test dataloader."""
        return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Get predict dataloader; requires ``setup("predict")`` to have created the inference data."""
        return DataLoader(
            self.inference_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers
        )