# datasets/all_classes_dataset.py
import os
from enum import Enum

import PIL
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# ImageNet channel statistics used to normalize RGB inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DatasetSplit(Enum):
    """Dataset split selector; ``.value`` is the on-disk directory name."""
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


class AllClassesDataset(Dataset):
    """Anomaly-detection dataset that loads every class found under ``source``.

    Assumed on-disk layout (MVTec-style — TODO confirm against the data):
        source/<classname>/<split>/<anomaly>/<image files>
        source/<classname>/ground_truth/<anomaly>/<mask files>
    """

    def __init__(
        self,
        source,
        input_size=518,
        output_size=224,
        split=DatasetSplit.TEST,
        external_transform=None,
        **kwargs,
    ):
        """
        Initialize the dataset to include all classes.

        Args:
            source (str): Path to the root data directory.
            input_size (int): Input image size for transformations.
            output_size (int): Output mask size.
            split (DatasetSplit): Dataset split to use (TRAIN, VAL, TEST).
            external_transform (callable, optional): External image transformations.
            **kwargs: Additional keyword arguments (ignored).
        """
        super().__init__()
        self.source = source
        self.split = split
        # Kept so __getitem__ can build a correctly-sized fallback image even
        # when an external transform (without a `.transforms` attribute) is used.
        self.input_size = input_size
        self.classnames_to_use = self.get_all_class_names()
        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        if external_transform is None:
            self.transform_img = transforms.Compose([
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ])
        else:
            self.transform_img = external_transform

        self.transform_mask = transforms.Compose([
            transforms.Resize((output_size, output_size)),
            transforms.ToTensor(),
        ])
        self.output_shape = (1, output_size, output_size)

    def get_all_class_names(self):
        """
        Retrieve all class names (subdirectories) from the source directory.

        Returns:
            list: List of class names.
        """
        return [
            item
            for item in os.listdir(self.source)
            if os.path.isdir(os.path.join(self.source, item))
        ]

    def get_image_data(self):
        """
        Collect image paths and corresponding mask paths for all classes.

        Returns:
            tuple: (imgpaths_per_class, data_to_iterate) where data_to_iterate
            is a list of [classname, anomaly, image_path, mask_path_or_None].
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            anomaly_types = os.listdir(classpath)

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    if os.path.exists(anomaly_mask_path):
                        anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                        maskpaths_per_class[classname][anomaly] = [
                            os.path.join(anomaly_mask_path, x)
                            for x in anomaly_mask_files
                        ]
                    else:
                        # If mask path does not exist, set to None
                        maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
                else:
                    # No pixel-level masks for normal images or non-test splits.
                    # BUGFIX: the original always wrote to the "good" key here,
                    # mislabeling entries when anomaly != "good" outside TEST.
                    maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if self.split == DatasetSplit.TEST and anomaly != "good":
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)

        return imgpaths_per_class, data_to_iterate

    def __getitem__(self, idx):
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
        try:
            image = PIL.Image.open(image_path).convert("RGB")
        except Exception:
            # Best-effort fallback: substitute a black image for unreadable files.
            # BUGFIX: the original passed the Resize transform's size tuple twice,
            # producing an invalid ((s, s), (s, s)) size argument, and crashed when
            # an external transform without a `.transforms` attribute was in use.
            image = PIL.Image.new(
                "RGB", (self.input_size, self.input_size), (0, 0, 0)
            )
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            try:
                mask = PIL.Image.open(mask_path).convert("L")
                mask = self.transform_mask(mask) > 0  # binarize to a bool tensor
            except Exception:
                # Unreadable mask: fall back to an all-normal (zero) mask.
                mask = torch.zeros(self.output_shape)
        else:
            mask = torch.zeros(self.output_shape)

        return {
            "image": image,   # Tensor: [3, input_size, input_size]
            "mask": mask,     # Tensor: [1, output_size, output_size]
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)