JohnWeck commited on
Commit
7ce6244
·
verified ·
1 Parent(s): e27420f

Upload 2 files

Browse files
dataset/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
dataset/dataset.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ import random
5
+ import os
6
+ import glob
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ from skimage.color import rgb2gray
10
+
11
+
12
+ def map_to_classes(label_array, max_pixel):
13
+ return np.clip(np.round(label_array * (max_pixel)), 0, max_pixel).astype(np.uint8)
14
+
15
+
16
+ def map_to_classes_isic(label_array):
17
+ image = np.where(label_array >= 0.5, 1, 0)
18
+ image = (image * 255.0).astype('uint8')
19
+ return image
20
+
21
+
22
+ def map_to_classes2(label_array):
23
+ image = np.where(label_array >= 0.5, 1, 0).astype('uint8')
24
+ return image
25
+
26
+
27
+ def center_crop(image, crop_size):
28
+ height, width = image.shape[:2]
29
+ crop_height, crop_width = crop_size
30
+ start_y = (height - crop_height) // 2
31
+ start_x = (width - crop_width) // 2
32
+
33
+ cropped_image = image[start_y:start_y + crop_height, start_x:start_x + crop_width]
34
+
35
+ return cropped_image
36
+
37
+
38
+ class MyDataset(torch.utils.data.Dataset):
39
+
40
+ def __init__(self, root, tokenizer, size=256, center_crop=True, t_drop_rate=0.05,
41
+ i_drop_rate=0.05, ti_drop_rate=0.05):
42
+ super().__init__()
43
+
44
+ self.tokenizer = tokenizer
45
+ self.size = size
46
+ self.center_crop = center_crop
47
+ self.i_drop_rate = i_drop_rate
48
+ self.t_drop_rate = t_drop_rate
49
+ self.ti_drop_rate = ti_drop_rate
50
+
51
+ self.data = glob.glob(os.path.join(root, '*', '*.npz'))
52
+
53
+ self.img_transform = transforms.Compose([
54
+ transforms.ToTensor(),
55
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
56
+ ])
57
+
58
+ self.mask_transform = transforms.Compose([
59
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
60
+ ])
61
+
62
+ self.max_pixels = {
63
+ 'AMOS2022': 15,
64
+ 'ACDC': 3,
65
+ 'BUSI': 1,
66
+ 'CVC-ClinicDB': 1,
67
+ 'kvasir-seg': 1,
68
+ 'LiTS2017': 2,
69
+ 'KiTS2019': 2,
70
+ }
71
+
72
+ self.AMOS2022 = {1:'liver',2:'right kidney',3:'spleen',4:'pancreas',5:'aorta',6:'inferior vena cava',7:'right adrenal gland',8:'left adrenal gland',
73
+ 9:'gall bladder',10:'esophagus',11:'stomach',12:'duodenum',13:'left kidney',14:'bladder',15:'prostate'}
74
+ self.ACDC = {1:'right ventricle',2:'myocardium',3:'left ventricle'}
75
+ self.LiTS2017 = {1:'liver',2:'liver tumor'}
76
+ self.KiTS2019 = {1:'kidney',2:'kidney tumor'}
77
+
78
+ self.aspect_ratios = [
79
+ (16, 9), # 16:9
80
+ (4, 3), # 4:3
81
+ (3, 2), # 3:2
82
+ (1, 1), # 1:1
83
+ (2, 1), # 2:1
84
+ (9, 16), # 9:16
85
+ (5, 4), # 5:4
86
+ (3, 4), # 3:4
87
+ (2, 3) # 2:3
88
+ ]
89
+
90
+ def get_target_size(self, aspect_ratio, max_size=512):
91
+ h_ratio, w_ratio = aspect_ratio
92
+ if h_ratio > w_ratio:
93
+ height = max_size
94
+ # print(w_ratio, h_ratio)
95
+ width = int(max_size * w_ratio / h_ratio)
96
+ else:
97
+ width = max_size
98
+ height = int(max_size * h_ratio / w_ratio)
99
+
100
+ return (height, width)
101
+
102
+ def convert_to_rgb(self, image):
103
+ if len(image.shape) == 2:
104
+ rgb_img = np.stack((image, image, image), axis=-1)
105
+ elif len(image.shape) == 3 and image.shape[2] == 3:
106
+ rgb_img = image
107
+ else:
108
+ raise ValueError("不支持的图像格式")
109
+
110
+ return rgb_img
111
+
112
+ def __getitem__(self, idx):
113
+ path = self.data[idx]
114
+ name = path.split('/')[-2]
115
+
116
+ # read image
117
+
118
+ raw_image, ori_raw_mask = np.load(path)['image'], np.load(path)['label']
119
+ kinds = np.unique(ori_raw_mask)
120
+ raw_image, raw_mask = self.convert_to_rgb(raw_image), self.convert_to_rgb(ori_raw_mask)
121
+
122
+ # original size
123
+ # aspect = self.aspect_ratios[random.randint(0, len(self.aspect_ratios) - 1)]
124
+ # shape = self.get_target_size(aspect, self.size)
125
+
126
+ image_tensor = self.img_transform(raw_image)
127
+ raw_mask = raw_mask / self.max_pixels[name]
128
+ raw_mask = torch.from_numpy(raw_mask.transpose((2, 0, 1))).contiguous()
129
+ mask_tensor = self.mask_transform(raw_mask)
130
+ # image_tensor = transforms.Resize(size=shape)(image_tensor)
131
+ # mask_tensor = transforms.Resize(size=shape)(mask_tensor)
132
+
133
+ image = image_tensor.squeeze(dim=0)
134
+ mask = mask_tensor.squeeze(dim=0)
135
+
136
+ organ, kind = '', ''
137
+ tips = []
138
+ if name == 'AMOS2022':
139
+ organ = 'abdomen CT scans'
140
+ for k in kinds:
141
+ if k == 0:
142
+ pass
143
+ else:
144
+ tips.append(self.AMOS2022[k])
145
+
146
+ if len(tips) != 0:
147
+ random.shuffle(tips)
148
+ for tip in tips:
149
+ if kind == '':
150
+ kind = tip
151
+ else:
152
+ kind = kind + ',' + tip
153
+
154
+ elif name == 'ACDC':
155
+ organ = 'cardiovascular ventricle mri'
156
+ for k in kinds:
157
+ if k == 0:
158
+ pass
159
+ else:
160
+ tips.append(self.ACDC[k])
161
+
162
+ if len(tips) != 0:
163
+ random.shuffle(tips)
164
+ for tip in tips:
165
+ if kind == '':
166
+ kind = tip
167
+ else:
168
+ kind = kind + ',' + tip
169
+
170
+ elif name == 'BUSI':
171
+ organ = 'breast ultrasound'
172
+ if not kinds.any():
173
+ kind = 'normal'
174
+ else:
175
+ kind = 'breast tumor'
176
+
177
+ elif name == 'CVC-ClinicDB':
178
+ organ = 'polyp colonoscopy'
179
+ if not kinds.any():
180
+ kind = 'normal'
181
+ else:
182
+ kind = 'polyp'
183
+
184
+ elif name == 'kvasir-seg':
185
+ organ = 'polyp colonoscopy'
186
+ if not kinds.any():
187
+ kind = 'normal'
188
+ else:
189
+ kind = 'polyp'
190
+
191
+ elif name == 'LiTS2017':
192
+ organ = 'abdomen CT scans'
193
+ for k in kinds:
194
+ if k == 0:
195
+ pass
196
+ else:
197
+ tips.append(self.LiTS2017[k])
198
+
199
+ if len(tips) != 0:
200
+ random.shuffle(tips)
201
+ for tip in tips:
202
+ if kind == '':
203
+ kind = tip
204
+ else:
205
+ kind = kind + ',' + tip
206
+
207
+ elif name == 'KiTS2019':
208
+ organ = 'abdomen CT scans'
209
+ for k in kinds:
210
+ if k == 0:
211
+ pass
212
+ else:
213
+ tips.append(self.KiTS2019[k])
214
+
215
+ if len(tips) != 0:
216
+ random.shuffle(tips)
217
+ for tip in tips:
218
+ if kind == '':
219
+ kind = tip
220
+ else:
221
+ kind = kind + ',' + tip
222
+
223
+ if kind == '':
224
+ kind = 'no found'
225
+
226
+ img_text = f'a photo of {organ} image, with {kind}.'
227
+ mask_text = f'a photo of {organ} label, with {kind}.'
228
+
229
+ # if name == 'LiTS2017':
230
+ # print(kinds, img_text)
231
+
232
+ # drop
233
+ rand_num = random.random()
234
+ if rand_num < self.i_drop_rate:
235
+ img_text = ""
236
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate):
237
+ mask_text = ""
238
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
239
+ img_text = ""
240
+ mask_text = ""
241
+
242
+ # get text and tokenize
243
+ img_text_input_ids = self.tokenizer(
244
+ img_text,
245
+ max_length=self.tokenizer.model_max_length,
246
+ padding="max_length",
247
+ truncation=True,
248
+ return_tensors="pt"
249
+ ).input_ids
250
+
251
+ mask_text_input_ids = self.tokenizer(
252
+ mask_text,
253
+ max_length=self.tokenizer.model_max_length,
254
+ padding="max_length",
255
+ truncation=True,
256
+ return_tensors="pt"
257
+ ).input_ids
258
+
259
+ return {
260
+ "image": image,
261
+ "mask": mask,
262
+ "img_text_input_ids": img_text_input_ids,
263
+ "mask_text_input_ids": mask_text_input_ids,
264
+ "raw_mask": ori_raw_mask,
265
+ "kind": kind
266
+ }
267
+
268
+ def __len__(self):
269
+ return len(self.data)
270
+
271
+
272
+ def collate_fn(data):
273
+
274
+ aspect_ratios = [
275
+ (16, 9), # 16:9
276
+ (4, 3), # 4:3
277
+ (3, 2), # 3:2
278
+ (1, 1), # 1:1
279
+ (2, 1), # 2:1
280
+ (9, 16), # 9:16
281
+ (5, 4), # 5:4
282
+ (3, 4), # 3:4
283
+ (2, 3) # 2:3
284
+ ]
285
+
286
+ def get_target_size(aspect_ratio, max_size=256):
287
+ h_ratio, w_ratio = aspect_ratio
288
+ if h_ratio > w_ratio:
289
+ height = max_size
290
+ # print(w_ratio, h_ratio)
291
+ width = int(max_size * w_ratio / h_ratio)
292
+ else:
293
+ width = max_size
294
+ height = int(max_size * h_ratio / w_ratio)
295
+
296
+ return (height, width)
297
+
298
+ aspect = aspect_ratios[random.randint(0, len(aspect_ratios) - 1)]
299
+ shape = get_target_size(aspect, 512)
300
+
301
+ images = torch.stack([transforms.Resize(size=shape)(example["image"]) for example in data])
302
+ masks = torch.stack([transforms.Resize(size=shape)(example["mask"]) for example in data])
303
+ img_text_input_ids = torch.cat([example["img_text_input_ids"] for example in data], dim=0)
304
+ mask_text_input_ids = torch.cat([example["mask_text_input_ids"] for example in data], dim=0)
305
+
306
+ return {
307
+ "images": images,
308
+ "masks": masks,
309
+ "img_text_input_ids": img_text_input_ids,
310
+ "mask_text_input_ids": mask_text_input_ids,
311
+ }
312
+