|
|
|
|
import os

from enum import Enum

import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule

import torch

from torch.utils.data import Dataset

from torchvision import transforms
|
|
|
| IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
|
|
|
class DatasetSplit(Enum):
    """Dataset split identifiers; ``value`` is the on-disk subdirectory name."""

    TRAIN = "train"

    VAL = "val"

    TEST = "test"
|
|
|
|
|
class AllClassesDataset(Dataset):
    """Image/mask dataset that spans every class found under ``source``.

    Assumes an MVTec-AD-style directory layout (TODO confirm against the
    data on disk):

        source/<classname>/<split>/<anomaly_type>/<image files>
        source/<classname>/ground_truth/<anomaly_type>/<mask files>

    Ground-truth masks exist only for the TEST split and only for anomaly
    types other than ``"good"``; all other samples get an all-zero mask.
    """

    def __init__(
        self,
        source,
        input_size=518,
        output_size=224,
        split=DatasetSplit.TEST,
        external_transform=None,
        **kwargs,
    ):
        """
        Initialize the dataset to include all classes.

        Args:
            source (str): Path to the root data directory.
            input_size (int): Input image size for transformations.
            output_size (int): Output mask size.
            split (DatasetSplit): Dataset split to use (TRAIN, VAL, TEST).
            external_transform (callable, optional): External image transformations.
            **kwargs: Additional keyword arguments (ignored).
        """
        super().__init__()
        self.source = source
        self.split = split
        # Stored so __getitem__ can build a correctly sized fallback image.
        # The previous approach (peeking at transform_img.transforms[0].size)
        # returned a tuple — making PIL.Image.new receive ((H, W), (H, W)) and
        # raise TypeError — and broke entirely with an external transform.
        self.input_size = input_size
        self.classnames_to_use = self.get_all_class_names()

        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        if external_transform is None:
            self.transform_img = transforms.Compose([
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ])
        else:
            self.transform_img = external_transform

        self.transform_mask = transforms.Compose([
            transforms.Resize((output_size, output_size)),
            transforms.ToTensor(),
        ])
        self.output_shape = (1, output_size, output_size)

    def get_all_class_names(self):
        """
        Retrieve all class names (subdirectories) from the source directory.

        Returns:
            list: List of class names.
        """
        all_items = os.listdir(self.source)
        return [
            item for item in all_items
            if os.path.isdir(os.path.join(self.source, item))
        ]

    def get_image_data(self):
        """
        Collect image paths and corresponding mask paths for all classes.

        Returns:
            tuple: (imgpaths_per_class, data_to_iterate) where each entry of
            data_to_iterate is ``[classname, anomaly, image_path, mask_path]``
            (mask_path is None when no ground-truth mask applies).
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            anomaly_types = os.listdir(classpath)

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    if os.path.exists(anomaly_mask_path):
                        anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                        maskpaths_per_class[classname][anomaly] = [
                            os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
                        ]
                    else:
                        # No ground-truth directory for this anomaly type:
                        # fall back to "no mask" placeholders.
                        maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
                else:
                    # Fixed: previously this always wrote under the "good" key,
                    # mis-keying (and clobbering) entries for non-TEST splits
                    # that contain more than one anomaly folder.
                    maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if self.split == DatasetSplit.TEST and anomaly != "good":
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)

        return imgpaths_per_class, data_to_iterate

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int): Index into ``self.data_to_iterate``.

        Returns:
            dict: image tensor, boolean/zero mask tensor, anomaly flag (int),
            and the image path.
        """
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
        try:
            image = PIL.Image.open(image_path).convert("RGB")
        except Exception:
            # Unreadable/missing image: substitute a black image of the
            # configured input size so a single bad file does not abort
            # the whole epoch.
            image = PIL.Image.new(
                "RGB", (self.input_size, self.input_size), (0, 0, 0)
            )
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            try:
                mask = PIL.Image.open(mask_path).convert("L")
                mask = self.transform_mask(mask) > 0
            except Exception:
                # Unreadable mask: treat as "no annotated anomaly pixels".
                mask = torch.zeros([*self.output_shape])
        else:
            mask = torch.zeros([*self.output_shape])

        return {
            "image": image,
            "mask": mask,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)
|
|
|