# ZeroShot-AD — datasets/all_classes_dataset.py
# Provenance (from file-viewer header): uploaded by HoomKh, revision e5461d8 (verified).
# The raw viewer text was converted to comments so the module stays valid Python.
# datasets/all_classes_dataset.py
import os
from enum import Enum
import PIL
import torch
from torch.utils.data import Dataset
from torchvision import transforms
# Per-channel RGB statistics used by transforms.Normalize below
# (the standard ImageNet values used throughout torchvision).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Closed set of dataset partitions; each value is the on-disk directory name."""

    TRAIN = "train"  # source/<class>/train/...
    VAL = "val"      # source/<class>/val/...
    TEST = "test"    # source/<class>/test/... (the only split that loads masks)
class AllClassesDataset(Dataset):
    """Anomaly-detection dataset that aggregates every class found under ``source``.

    Expects an MVTec-AD style directory layout:

        source/<classname>/<split>/<anomaly_type>/<image files>
        source/<classname>/ground_truth/<anomaly_type>/<mask files>

    Each item is a dict holding the transformed image, a binarized ground-truth
    mask (all zeros for "good" samples, non-TEST splits, or missing masks), an
    ``is_anomaly`` flag, and the image path.
    """

    def __init__(
        self,
        source,
        input_size=518,
        output_size=224,
        split=DatasetSplit.TEST,
        external_transform=None,
        **kwargs,
    ):
        """
        Initialize the dataset to include all classes.

        Args:
            source (str): Path to the root data directory.
            input_size (int): Input image size for transformations.
            output_size (int): Output mask size.
            split (DatasetSplit): Dataset split to use (TRAIN, VAL, TEST).
            external_transform (callable, optional): External image transformations.
            **kwargs: Additional keyword arguments (accepted and ignored).
        """
        super().__init__()
        self.source = source
        self.split = split
        # Bug fix: keep the integer input size so __getitem__ can build a
        # correctly sized fallback image. The original read
        # transform_img.transforms[0].size, which is an (H, W) *tuple* for
        # Resize((h, w)) — so Image.new got a tuple-of-tuples and raised —
        # and which does not exist at all when external_transform is used.
        self.input_size = input_size
        self.classnames_to_use = self.get_all_class_names()
        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        if external_transform is None:
            self.transform_img = transforms.Compose([
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ])
        else:
            self.transform_img = external_transform

        self.transform_mask = transforms.Compose([
            transforms.Resize((output_size, output_size)),
            transforms.ToTensor(),
        ])
        self.output_shape = (1, output_size, output_size)

    def get_all_class_names(self):
        """Return the names of all class subdirectories under ``self.source``.

        Returns:
            list: Class (subdirectory) names; plain files are skipped.
        """
        all_items = os.listdir(self.source)
        return [
            item for item in all_items
            if os.path.isdir(os.path.join(self.source, item))
        ]

    def get_image_data(self):
        """Collect image paths and corresponding mask paths for all classes.

        Returns:
            tuple: ``(imgpaths_per_class, data_to_iterate)`` where
            ``data_to_iterate`` is a list of
            ``[classname, anomaly, image_path, mask_path]`` entries.
            ``mask_path`` is None for "good" samples, non-TEST splits, or
            defect types with no ground-truth directory on disk.
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}
        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            anomaly_types = os.listdir(classpath)
            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}
            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]
                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    if os.path.exists(anomaly_mask_path):
                        anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                        maskpaths_per_class[classname][anomaly] = [
                            os.path.join(anomaly_mask_path, x)
                            for x in anomaly_mask_files
                        ]
                    else:
                        # No ground-truth masks on disk for this defect type.
                        maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)
                else:
                    # Bug fix: file placeholders under the actual anomaly key.
                    # The original hard-coded the key "good" here, misfiling
                    # (and overwriting) entries for non-TEST splits whose
                    # split directory contains more than one subfolder.
                    maskpaths_per_class[classname][anomaly] = [None] * len(anomaly_files)

        # Flatten into a deterministic (sorted) iteration order.
        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if self.split == DatasetSplit.TEST and anomaly != "good":
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)
        return imgpaths_per_class, data_to_iterate

    def __getitem__(self, idx):
        """Load and transform the sample at ``idx``.

        Returns:
            dict: ``image`` (Tensor [3, input_size, input_size]), ``mask``
            (Tensor [1, output_size, output_size]; bool where a real mask was
            loaded, float zeros otherwise), ``is_anomaly`` (int 0/1), and
            ``image_path`` (str).
        """
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
        try:
            image = PIL.Image.open(image_path).convert("RGB")
        except Exception:
            # Best-effort fallback: substitute a black image rather than
            # crashing the whole DataLoader on a single unreadable file.
            image = PIL.Image.new(
                "RGB", (self.input_size, self.input_size), (0, 0, 0)
            )
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            try:
                mask = PIL.Image.open(mask_path).convert("L")
                mask = self.transform_mask(mask) > 0  # binarize to bool tensor
            except Exception:
                mask = torch.zeros(self.output_shape)
        else:
            mask = torch.zeros(self.output_shape)

        return {
            "image": image,        # Tensor: [3, input_size, input_size]
            "mask": mask,          # Tensor: [1, output_size, output_size]
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        """Total number of samples across all classes and anomaly types."""
        return len(self.data_to_iterate)