File size: 4,138 Bytes
005ccd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageOps
import numpy as np
import json
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2


import logging

# Root logger: the same object logging.getLogger() returns when called without a name.
# NOTE(review): logging.getLogger(__name__) would let callers filter this
# module's output independently of everything else.
logger = logging.root

class ImageDataset(Dataset):
    """Anomaly-detection image dataset driven by a JSON meta description.

    The meta description (a JSON file path, or an already-parsed dict passed as
    ``meta_file``) is indexed as ``meta[mode][class_name] -> list[record]``.
    Each record must provide ``img_path``, ``mask_path``, ``cls_name`` and
    ``anomaly`` keys. In ``"train"`` mode the records of every class under
    ``meta["train"]`` are concatenated; in any other mode only
    ``meta[mode][test_class]`` is used.

    Each item is a dict with:
        image:   float tensor (3, resize, resize) with values in [0, 1]
        mask:    float tensor (1, resize, resize); 1.0 where the mask file is
                 non-zero, all zeros for records with ``anomaly == 0``
        fg_mask: float tensor (1, resize, resize); all zeros when
                 ``with_fg_mask`` is False
        cls_name, anomaly: copied from the record
    The same format is returned whether or not augmentation is enabled.
    """

    def __init__(self,
                 data_root,
                 meta_file="",  # meta file path, or an already-loaded meta dict
                 resize=240,
                 mode="train",
                 aug=False,
                 with_fg_mask=False,
                 test_class='None',
                 fg_mask_root="sam2_fg_mask",
                 ):
        """
        Args:
            data_root: Directory joined in front of each record's ``img_path``
                and ``mask_path``.
            meta_file: Path to the JSON meta file, or the parsed meta dict itself.
            resize: Side length of the square output image and masks.
            mode: Top-level meta key to read (``"train"`` selects every class).
            aug: Apply random flips/shifts; only honoured when ``mode == "train"``.
            with_fg_mask: Also load a foreground mask for each image.
            test_class: Class key to load when ``mode != "train"``.
            fg_mask_root: Directory mirroring ``data_root``'s layout that holds
                the foreground masks as ``.png`` files.
        """
        self.data_root = data_root
        self.resize = resize
        self.mode = mode
        self.test_class = test_class
        self.with_fg_mask = with_fg_mask
        self.aug = aug
        self.fg_mask_root = fg_mask_root

        if isinstance(meta_file, str):
            # Context manager so the file handle is closed promptly.
            with open(meta_file, 'r') as f:
                meta_info = json.load(f)
        else:
            meta_info = meta_file

        split = meta_info[mode]
        if self.mode == "train":
            # Training uses every class (normal and anomalous records alike).
            self.data_list = [record for records in split.values() for record in records]
            self.class_names = list(split.keys())
        else:
            self.data_list = list(split[test_class])
            self.class_names = [test_class]

        size = (self.resize, self.resize)
        self.resize_img_transform = transforms.Resize(
            size, interpolation=transforms.InterpolationMode.BICUBIC)
        # Nearest-neighbour keeps mask values strictly binary (0/255).
        self.resize_mask_transform = transforms.Resize(
            size, interpolation=transforms.InterpolationMode.NEAREST)
        self.aug_transform = A.Compose(
            [
                A.HorizontalFlip(p=0.2),
                A.VerticalFlip(p=0.2),
                A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0, rotate_limit=0, p=0.2),
                ToTensorV2(),
            ],
            # fg_mask is treated as a mask and warped with the *same* sampled
            # parameters as image/mask, so all three stay spatially aligned.
            additional_targets={"fg_mask": "mask"},
        )

    def _fg_mask_path(self, img_path):
        """Map a full image path under ``data_root`` to its ``.png`` mask under ``fg_mask_root``."""
        rel_path = os.path.relpath(img_path, self.data_root)
        return os.path.join(self.fg_mask_root, os.path.splitext(rel_path)[0] + ".png")

    def _load_binary_mask(self, path):
        """Load ``path``, binarize (non-zero -> 255) and resize it to an 'L' image."""
        with Image.open(path) as src:
            binary = np.array(src.convert('L')) > 0
        return self.resize_mask_transform(Image.fromarray(binary.astype(np.uint8) * 255))

    @staticmethod
    def _mask_to_tensor(mask):
        """Convert a 0/255 mask (array or tensor, (H, W) or (1, H, W)) to float (1, H, W) in [0, 1]."""
        mask = torch.as_tensor(mask).float() / 255.0
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask

    def __getitem__(self, idx):
        data = self.data_list[idx]
        img_path = os.path.join(self.data_root, data["img_path"])
        mask_path = os.path.join(self.data_root, data["mask_path"])
        cls_name, anomaly = data["cls_name"], data["anomaly"]

        # Apply EXIF orientation before converting, while the source metadata is intact.
        # NOTE(review): masks are not EXIF-transposed; confirm mask files carry no
        # orientation tag, otherwise image and mask could be misaligned.
        with Image.open(img_path) as src:
            image = ImageOps.exif_transpose(src).convert('RGB')
        image = self.resize_img_transform(image)

        if anomaly == 0:
            # Normal sample: all-background mask, already at the target size.
            mask = Image.fromarray(np.zeros((self.resize, self.resize), dtype=np.uint8))
        else:
            mask = self._load_binary_mask(mask_path)

        if self.with_fg_mask:
            fg_mask = self._load_binary_mask(self._fg_mask_path(img_path))
        else:
            fg_mask = torch.zeros(1, self.resize, self.resize)

        if self.mode == "train" and self.aug:
            targets = {"image": np.array(image).astype(np.float32), "mask": np.array(mask)}
            if self.with_fg_mask:
                targets["fg_mask"] = np.array(fg_mask)
            # One call so every target receives identical random parameters.
            augmented = self.aug_transform(**targets)
            # ToTensorV2 does not rescale; divide by 255 to match ToTensor's [0, 1] output.
            image = augmented["image"] / 255.0
            mask = self._mask_to_tensor(augmented["mask"])
            if self.with_fg_mask:
                fg_mask = self._mask_to_tensor(augmented["fg_mask"])
        else:
            to_tensor = transforms.ToTensor()
            image = to_tensor(image)
            mask = to_tensor(mask)
            if self.with_fg_mask:
                fg_mask = to_tensor(fg_mask)

        return {"image": image, "mask": mask, "fg_mask": fg_mask, "cls_name": cls_name, "anomaly": anomaly}

    def __len__(self):
        return len(self.data_list)



if __name__ == '__main__':
    # Smoke test: build a split from the command line and show one sample's shapes.
    import argparse

    parser = argparse.ArgumentParser(description="Smoke-test ImageDataset.")
    parser.add_argument("data_root", help="Directory joined in front of img_path/mask_path.")
    parser.add_argument("meta_file", help="Path to the JSON meta file.")
    parser.add_argument("--mode", default="train")
    parser.add_argument("--test-class", default="None")
    parser.add_argument("--resize", type=int, default=240)
    args = parser.parse_args()

    ds = ImageDataset(args.data_root, meta_file=args.meta_file, resize=args.resize,
                      mode=args.mode, test_class=args.test_class)
    print(f"{len(ds)} samples, classes: {ds.class_names}")
    if len(ds) > 0:
        sample = ds[0]
        print({k: tuple(v.shape) if hasattr(v, "shape") else v for k, v in sample.items()})