File size: 12,370 Bytes
e5461d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
from torchvision import transforms
from .perlin import perlin_mask
from enum import Enum

import numpy as np
import pandas as pd
import logging

LOGGER = logging.getLogger(__name__)
import PIL
import torch
import os
import glob

# Names of the MVTec classes this module knows about.
_CLASSNAMES = (
    "carpet grid leather tile wood "
    "bottle cable capsule hazelnut metal_nut pill screw toothbrush transistor zipper"
).split()

# Per-channel (R, G, B) mean / std handed to transforms.Normalize in this module.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DatasetSplit(Enum):
    """Dataset partition to load; each value doubles as the split's sub-folder name."""

    TRAIN = "train"  # normal images only; synthetic anomalies are generated on the fly
    TEST = "test"    # evaluation images, loaded together with ground-truth masks


class MVTecDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for MVTec-style anomaly-detection folders.

    Expected layout (as read by ``get_image_data``)::

        <source>/<classname>/<split>/<anomaly_type>/<image files>
        <source>/<classname>/ground_truth/<anomaly_type>/<mask files>

    In the TRAIN split every sample also carries a synthetic anomaly: a mask
    produced by ``perlin_mask`` is used to blend an (augmented) image taken
    from ``anomaly_source_path`` into the normal image.
    """

    def __init__(
            self,
            source,
            anomaly_source_path='/root/dataset/dtd/images',
            dataset_name='mvtec',
            classname='leather',
            resize=288,
            imagesize=288,
            split=DatasetSplit.TRAIN,
            rotate_degrees=0,
            translate=0,
            brightness_factor=0,
            contrast_factor=0,
            saturation_factor=0,
            gray_p=0,
            h_flip_p=0,
            v_flip_p=0,
            distribution=0,
            mean=0.5,
            std=0.1,
            fg=0,
            rand_aug=1,
            scale=0,
            batch_size=8,
            **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the MVTec data folder.
            anomaly_source_path: [str]. Root folder searched with the pattern
                       ``*/*/*/*.png`` for images that are blended in as
                       synthetic anomalies during training.
            dataset_name: [str]. Used to locate the foreground lookup table
                       ``./datasets/excel/<dataset_name>_distribution.xlsx``.
            classname: [str]. Name of the MVTec class to load. Must be a
                       string; ``None`` is not supported.
            resize: [int]. Size the loaded image initially gets resized to.
                    When ``distribution == 1`` the image is resized to
                    ``(resize, resize)`` instead.
            imagesize: [int]. (Square) Size the resized image gets
                       (center-)cropped to.
            split: [DatasetSplit]. TRAIN or TEST. TEST also loads
                   ground-truth masks.
            rotate_degrees, translate, scale: RandomAffine parameters.
            brightness_factor, contrast_factor, saturation_factor:
                   ColorJitter parameters.
            gray_p, h_flip_p, v_flip_p: probabilities for RandomGrayscale,
                   RandomHorizontalFlip and RandomVerticalFlip.
            distribution: [int]. 1 selects square resizing and disables the
                   toothbrush/wood resize override.
            mean, std: [float]. Normal distribution from which the blending
                   weight ``beta`` is drawn (then clipped to [0.2, 0.8]).
            fg: [int]. Foreground-mask mode: 0 = off, 1 = on, 2 = read from
                the xlsx table (falls back to on if the lookup fails).
            rand_aug: If truthy, anomaly sources get three random
                   augmentations from ``rand_augmenter``; otherwise
                   ``transform_img`` is applied to them.
            batch_size: [int]. Stored on the instance; not used here.
            **kwargs: Accepted and ignored.

        Note: ``transform_img`` (including its random flips/affine/jitter) is
        applied in every split, so test-time callers presumably pass zeros
        for the augmentation arguments.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.batch_size = batch_size
        self.distribution = distribution
        self.mean = mean
        self.std = std
        self.fg = fg
        self.rand_aug = rand_aug
        # An int makes torchvision's Resize scale the shorter edge (keeping the
        # aspect ratio); a [h, w] pair forces an exact square resize.
        self.resize = resize if self.distribution != 1 else [resize, resize]
        self.imgsize = imagesize
        self.imagesize = (3, self.imgsize, self.imgsize)
        self.classname = classname
        self.dataset_name = dataset_name

        if self.distribution != 1 and self.classname in ('toothbrush', 'wood'):
            # Class-specific larger resize (329/288 of the crop size) before cropping.
            self.resize = round(self.imgsize * 329 / 288)

        self.class_fg = self._resolve_class_fg()

        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
        self.anomaly_source_paths = sorted(
            glob.glob(os.path.join(anomaly_source_path, "*", "*", "*", "*.png"))
        )
        if self.anomaly_source_paths:
            LOGGER.info("Found %d anomaly source images under %s.",
                        len(self.anomaly_source_paths), anomaly_source_path)
        elif self.split == DatasetSplit.TRAIN:
            LOGGER.warning("No anomaly source images found under %s; training samples "
                           "will blend the original image into itself.", anomaly_source_path)

        self.transform_img = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0 - scale, 1.0 + scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

        self.transform_mask = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
        ])

    def _resolve_class_fg(self):
        """Return the foreground-mask flag for this class according to ``self.fg``."""
        if self.fg == 1:  # with foreground mask
            return 1
        if self.fg != 2:  # without foreground mask
            return 0
        # fg == 2: choose by file, defaulting to "on" when the lookup fails.
        xlsx_path = './datasets/excel/' + self.dataset_name + '_distribution.xlsx'
        try:
            df = pd.read_excel(xlsx_path)
            row_key = self.dataset_name + '_' + self.classname
            return df.loc[df['Class'] == row_key, 'Foreground'].values[0]
        except Exception as exc:  # missing file/column/row, or no Excel engine installed
            LOGGER.warning("Could not read foreground flag for %s from %s (%s); defaulting to 1.",
                           self.classname, xlsx_path, exc)
            return 1

    def rand_augmenter(self):
        """Build a transform applying 3 distinct randomly chosen augmentations between resize and crop."""
        list_aug = [
            transforms.ColorJitter(contrast=(0.8, 1.2)),
            transforms.ColorJitter(brightness=(0.8, 1.2)),
            transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
            transforms.RandomGrayscale(p=1),
            transforms.RandomAutocontrast(p=1),
            transforms.RandomEqualize(p=1),
            transforms.RandomAffine(degrees=(-45, 45)),
        ]
        # NOTE(review): NumPy's global RNG state may be duplicated across
        # DataLoader worker processes unless a worker_init_fn reseeds it, so
        # workers could draw identical augmentations — confirm the caller seeds.
        aug_idx = np.random.choice(len(list_aug), 3, replace=False)

        return transforms.Compose([
            transforms.Resize(self.resize),
            *(list_aug[i] for i in aug_idx),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    @staticmethod
    def _open_rgb(path):
        """Open the image at ``path``, convert it to RGB and release the file handle."""
        with PIL.Image.open(path) as img:
            return img.convert("RGB")

    def _load_anomaly_source(self, idx, image):
        """Return a transformed anomaly-source image, or ``image`` itself if none are available."""
        # np.random.choice raises ValueError (not IndexError) on an empty list,
        # so check explicitly to make this fallback reachable.
        if not self.anomaly_source_paths:
            LOGGER.warning("No anomaly source images available. Using original image "
                           "as augmentation for index %s.", idx)
            return image
        source_path = np.random.choice(self.anomaly_source_paths)
        transform = self.rand_augmenter() if self.rand_aug else self.transform_img
        return transform(self._open_rgb(source_path))

    def _load_fg_mask(self, idx, classname, image_path, default):
        """Load the binarised foreground mask for ``image_path``, or return ``default`` if it is missing."""
        # Build from the dataset root rather than splitting image_path on the
        # class name, which breaks when the root itself contains that name.
        fgmask_path = os.path.join(
            self.source,
            classname,
            "ground_truth",
            os.path.basename(image_path).replace(".png", "_mask.png"),
        )
        if not os.path.exists(fgmask_path):
            # Fall back to the same unrestricted default used when class_fg == 0,
            # rather than an all-zero mask that would leave no region for the
            # synthetic anomaly.
            LOGGER.warning("Foreground mask not found: %s. Skipping mask for index %s.",
                           fgmask_path, idx)
            return default
        with PIL.Image.open(fgmask_path) as mask_img:
            return torch.ceil(self.transform_mask(mask_img)[0])

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns a dict with keys ``image``, ``aug``, ``mask_s``, ``mask_gt``,
        ``is_anomaly`` and ``image_path``. In the TEST split ``aug`` and
        ``mask_s`` are the placeholder ``torch.tensor([1])``. Returns ``None``
        if the image file is missing or any error occurs (the error is logged).
        """
        # NOTE(review): returning None breaks torch's default collate; this
        # presumably relies on a caller-side collate_fn that drops None — confirm.
        try:
            classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]

            if not os.path.exists(image_path):
                LOGGER.warning("Image not found: %s. Skipping index %s.", image_path, idx)
                return None

            image = self.transform_img(self._open_rgb(image_path))

            # Placeholders kept when no synthetic anomaly is generated.
            mask_fg = mask_s = aug_image = torch.tensor([1])

            if self.split == DatasetSplit.TRAIN:
                aug = self._load_anomaly_source(idx, image)

                if self.class_fg:
                    mask_fg = self._load_fg_mask(idx, classname, image_path, default=mask_fg)

                # perlin_mask returns (small mask, full-resolution mask) given flag=1.
                mask_all = perlin_mask(image.shape, self.imgsize // 8, 0, 6, mask_fg, 1)
                mask_s = torch.from_numpy(mask_all[0])
                mask_l = torch.from_numpy(mask_all[1])

                # Inside mask_l, mix the anomaly source and the original image with weight beta.
                beta = np.clip(np.random.normal(loc=self.mean, scale=self.std), 0.2, 0.8)
                aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l

            mask_gt = torch.zeros([1, *image.size()[1:]])
            if self.split == DatasetSplit.TEST and mask_path is not None:
                if os.path.exists(mask_path):
                    with PIL.Image.open(mask_path) as gt_img:
                        mask_gt = self.transform_mask(gt_img.convert("L"))
                else:
                    LOGGER.warning("Ground truth mask not found: %s. Using default empty mask "
                                   "for index %s.", mask_path, idx)

            return {
                "image": image,
                "aug": aug_image,
                "mask_s": mask_s,
                "mask_gt": mask_gt,
                "is_anomaly": int(anomaly != "good"),
                "image_path": image_path,
            }

        except Exception as e:
            # Log with traceback so real bugs are not hidden behind a None sample.
            LOGGER.exception("Error processing index %s: %s", idx, e)
            return None

    def __len__(self):
        """Number of valid samples found by ``get_image_data``."""
        return len(self.data_to_iterate)

    @staticmethod
    def _sorted_paths(directory):
        """Return sorted full paths of the non-hidden entries (e.g. no .DS_Store) in ``directory``."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith(".")
        ]

    def get_image_data(self):
        """
        Collect image (and, for TEST, ground-truth mask) paths for ``self.classname``.

        Images and masks of an anomaly type are paired by position after
        sorting, so hidden files are ignored to keep the pairing aligned.

        Returns:
            imgpaths_per_class: ``{classname: {anomaly_type: [image paths]}}``.
            data_to_iterate: list of ``[classname, anomaly_type, image_path,
                mask_path]``; ``mask_path`` is None for TRAIN samples and for
                "good" TEST samples.

        Raises:
            ValueError: if no usable sample was found.
        """
        classpath = os.path.join(self.source, self.classname, self.split.value)
        maskpath = os.path.join(self.source, self.classname, "ground_truth")

        # Only sub-directories are anomaly types; stray files would make listdir fail.
        anomaly_types = [
            entry for entry in os.listdir(classpath)
            if not entry.startswith(".") and os.path.isdir(os.path.join(classpath, entry))
        ]

        imgpaths_per_class = {self.classname: {}}
        maskpaths_per_class = {self.classname: {}}

        for anomaly in anomaly_types:
            imgpaths_per_class[self.classname][anomaly] = self._sorted_paths(
                os.path.join(classpath, anomaly)
            )

            if self.split == DatasetSplit.TEST and anomaly != "good":
                anomaly_mask_path = os.path.join(maskpath, anomaly)
                if os.path.exists(anomaly_mask_path):
                    maskpaths_per_class[self.classname][anomaly] = self._sorted_paths(anomaly_mask_path)
                else:
                    LOGGER.warning("Anomaly mask path does not exist: %s. Skipping masks for %s.",
                                   anomaly_mask_path, anomaly)
                    maskpaths_per_class[self.classname][anomaly] = []

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class):
            for anomaly in sorted(imgpaths_per_class[classname]):
                needs_mask = self.split == DatasetSplit.TEST and anomaly != "good"
                mask_paths = maskpaths_per_class[classname].get(anomaly, [])
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    mask_path = None
                    if needs_mask:
                        if i >= len(mask_paths):
                            LOGGER.warning("No corresponding mask for %s. Skipping.", image_path)
                            continue
                        mask_path = mask_paths[i]

                    if os.path.exists(image_path) and (mask_path is None or os.path.exists(mask_path)):
                        data_to_iterate.append([classname, anomaly, image_path, mask_path])
                    else:
                        LOGGER.warning("Missing required file for %s or %s. Skipping.",
                                       image_path, mask_path)

        if not data_to_iterate:
            raise ValueError("No valid data found. Please check dataset paths and files.")

        return imgpaths_per_class, data_to_iterate