Continual-Mega commited on
Commit
005ccd2
·
1 Parent(s): 44e47d3

Add: dataset continual.py

Browse files
Files changed (1) hide show
  1. dataset/continual.py +119 -0
dataset/continual.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from torchvision import transforms
5
+ from PIL import Image, ImageOps
6
+ import numpy as np
7
+ import json
8
+ import cv2
9
+ import albumentations as A
10
+ from albumentations.pytorch import ToTensorV2
11
+
12
+
13
+ import logging
14
+
15
+ logger = logging.getLogger()
16
+
17
+ class ImageDataset(Dataset):
18
+ def __init__(self,
19
+ data_root,
20
+ meta_file="", # meta file path
21
+ resize=240,
22
+ mode="train",
23
+ aug=False,
24
+ with_fg_mask=False,
25
+ test_class='None'
26
+ ):
27
+
28
+ self.data_root = data_root
29
+ self.resize = resize
30
+ self.mode = mode
31
+ self.test_class = test_class
32
+ self.with_fg_mask = with_fg_mask
33
+ self.aug = aug
34
+
35
+ if isinstance(meta_file, str):
36
+ meta_info = json.load(open(meta_file, 'r'))
37
+ else:
38
+ meta_info = meta_file
39
+
40
+ self.data_list = []
41
+ if self.mode == "train":
42
+ meta_info = meta_info[mode]
43
+ for cls_name, data_list in meta_info.items():
44
+ self.data_list.extend(data_list)
45
+ # for data in data_list:
46
+ # if data["anomaly"] == 0:
47
+ # self.data_list.append(data)
48
+ self.class_names = list(meta_info.keys())
49
+ else:
50
+ meta_info = meta_info[mode][test_class]
51
+ self.data_list.extend(meta_info)
52
+ self.class_names = [test_class]
53
+
54
+
55
+ self.resize_img_transform = transforms.Resize((self.resize, self.resize), interpolation=Image.BICUBIC)
56
+ self.resize_mask_transform = transforms.Resize((self.resize, self.resize), interpolation=Image.NEAREST)
57
+ self.aug_transform = A.Compose([
58
+ A.HorizontalFlip(p=0.2),
59
+ A.VerticalFlip(p=0.2),
60
+ A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0, rotate_limit=0, p=0.2),
61
+ # A.Rotate(limit=30, p=0.5),
62
+ ToTensorV2()
63
+ ])
64
+
65
+ def __getitem__(self, idx):
66
+ data = self.data_list[idx]
67
+ img_path, mask_path, cls_name, anomaly = data["img_path"], data["mask_path"], data["cls_name"], data["anomaly"]
68
+
69
+ img_path = os.path.join(self.data_root, img_path)
70
+ mask_path = os.path.join(self.data_root, mask_path)
71
+ if self.with_fg_mask:
72
+ fg_mask_path = img_path.replace(self.data_root, "sam2_fg_mask")
73
+ fg_mask_path = fg_mask_path[:-3] + "png"
74
+
75
+ image = Image.open(img_path).convert('RGB')
76
+ image = ImageOps.exif_transpose(image)
77
+ image = self.resize_img_transform(image)
78
+
79
+ if anomaly == 0:
80
+ mask = Image.fromarray(np.zeros((self.resize, self.resize)), mode='L')
81
+ else:
82
+ mask = np.array(Image.open(mask_path).convert('L')) > 0
83
+ mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
84
+ mask = self.resize_mask_transform(mask)
85
+
86
+ if self.with_fg_mask:
87
+ fg_mask = np.array(Image.open(fg_mask_path).convert('L')) > 0
88
+ fg_mask = Image.fromarray(fg_mask.astype(np.uint8) * 255, mode='L')
89
+ fg_mask = self.resize_mask_transform(fg_mask)
90
+
91
+ else:
92
+ fg_mask = torch.zeros(1, self.resize, self.resize)
93
+
94
+ if self.mode == "train" and self.aug:
95
+ image = np.array(image).astype(np.float32)
96
+ mask = np.array(mask)
97
+ augmented = self.aug_transform(image=image, mask=mask)
98
+ image = augmented['image']
99
+ mask = augmented['mask']
100
+ if self.with_fg_mask:
101
+ fg_mask = np.array(fg_mask)
102
+ fg_mask = self.aug_transform(mask=fg_mask)['mask']
103
+
104
+ else:
105
+ image = transforms.ToTensor()(image)
106
+ mask = transforms.ToTensor()(mask)
107
+ if self.with_fg_mask:
108
+ fg_mask = transforms.ToTensor()(fg_mask)
109
+
110
+ return {"image": image, "mask": mask, "fg_mask": fg_mask, "cls_name": cls_name, "anomaly": anomaly}
111
+
112
+ def __len__(self):
113
+ return len(self.data_list)
114
+
115
+
116
+
117
+ if __name__ == '__main__':
118
+
119
+ ds = ImageDataset(is_train=True)