File size: 5,091 Bytes
226675b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from torch.utils.data import Dataset
from .transform import *
import albumentations as albu
from PIL import Image
import numpy as np
import os
import torch
import albumentations as A

class BaseDataset(Dataset):
    """Base dataset yielding a pair of images (A and B) plus a label mask per sample.

    This class does not define several attributes/methods it relies on; a subclass
    must provide them:

    * ``file_paths`` -- list of sample ids (indexable, used by ``__len__``),
    * ``data_root``, ``imgA_dir``, ``imgB_dir``, ``mask_dir`` -- path components,
    * ``img_suffix``, ``mask_suffix`` -- file extensions appended to each id,
    * ``rgb2label(mask_rgb)`` -- converts the RGB mask image to a label array.
    """

    def __init__(self, transform=None, mode="train"):
        """Build the joint augmentation pipeline and the per-image normalizers.

        Args:
            transform: Mapping from an augmentation expression to its keyword
                arguments. Each key is ``eval``-uated in this module's namespace
                (e.g. a class exported by ``.transform``); the value is either a
                dict of kwargs or ``None`` for a no-argument constructor. ``None``
                or an empty mapping disables augmentation.
            mode: Stored as ``self.mode``; not read by this base class.
        """
        # Only referenced by the commented-out mosaic loader below.
        self.mosaic_ratio = 0.25
        self.mode = mode
        self.img_size = (1024, 1024)

        aug_list = []
        # The default ``transform=None`` must not crash: treat it as "no augmentation".
        for name, kwargs in (transform or {}).items():
            # NOTE(review): eval() executes arbitrary code; the config keys must
            # come from a trusted source.
            factory = eval(name)
            aug_list.append(factory(**kwargs) if kwargs is not None else factory())

        # ``Compose`` comes from ``.transform`` and receives image pairs plus a mask
        # (see __getitem__); an empty list means no augmentation is applied.
        self.transform = Compose(aug_list)

        # Separate (currently identical) normalizers for image A and image B.
        self.t1_normalize = A.Compose([
            A.Normalize()
        ])

        self.t2_normalize = A.Compose([
            A.Normalize()
        ])

    def __len__(self):
        """Return the number of samples (``len(self.file_paths)``)."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Load, augment, normalize and tensorize sample ``index``.

        Returns:
            ``(imgA, imgB, mask, img_id)``: the two images as float tensors permuted
            with ``(2, 0, 1)`` (channels first), the mask as a long tensor, and the
            sample id.
        """
        imgA, imgB, mask, img_id = self.load_img_and_mask(index)
        if len(self.transform.transforms) != 0:
            # Both images and the mask go through the same augmentation call so
            # geometric transforms stay aligned across them.
            imgA, imgB, mask = self.transform([imgA, imgB], mask)
        imgA, imgB, mask = np.array(imgA), np.array(imgB), np.array(mask)
        # albumentations Normalize with its default parameters.
        imgA = self.t1_normalize(image=imgA)['image']
        imgB = self.t2_normalize(image=imgB)['image']
        # HWC numpy arrays -> CHW float tensors.
        imgA, imgB = [torch.from_numpy(img).permute(2, 0, 1).float() for img in [imgA, imgB]]
        mask = torch.from_numpy(mask).long()
        return imgA, imgB, mask, img_id

    def get_path(self, data_root, imgA_dir, imgB_dir, mask_dir):
        """Return the sample ids found in ``data_root/mask_dir``, sorted by file name.

        An id is the file name with only its final extension removed, so names
        containing dots (``tile.v1.png`` -> ``tile.v1``) round-trip correctly when
        ``load_img_and_mask`` appends the suffix again. Sorting makes the index ->
        sample mapping independent of ``os.listdir``'s arbitrary order.

        Raises:
            AssertionError: if the image-A or image-B directory does not contain
                the same number of entries as the mask directory.
        """
        imgA_filename_list = os.listdir(os.path.join(data_root, imgA_dir))
        imgB_filename_list = os.listdir(os.path.join(data_root, imgB_dir))
        mask_filename_list = os.listdir(os.path.join(data_root, mask_dir))
        # NOTE: ``assert`` is stripped under ``python -O``; these checks then vanish.
        assert len(imgA_filename_list) == len(mask_filename_list), (
            f"{imgA_dir} has {len(imgA_filename_list)} files but "
            f"{mask_dir} has {len(mask_filename_list)}"
        )
        assert len(imgB_filename_list) == len(mask_filename_list), (
            f"{imgB_dir} has {len(imgB_filename_list)} files but "
            f"{mask_dir} has {len(mask_filename_list)}"
        )
        return [os.path.splitext(name)[0] for name in sorted(mask_filename_list)]

    @staticmethod
    def _open_rgb(path):
        """Open the image at *path*, return an in-memory RGB copy, and close the file."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def load_img_and_mask(self, index):
        """Read image A, image B and the mask for sample ``index`` from disk.

        Returns:
            ``(imgA, imgB, mask, img_id)`` where the images are RGB PIL images and
            ``mask`` is whatever the subclass's ``rgb2label`` returns.
        """
        img_id = self.file_paths[index]
        imgA_name = os.path.join(self.data_root, self.imgA_dir, img_id + self.img_suffix)
        imgB_name = os.path.join(self.data_root, self.imgB_dir, img_id + self.img_suffix)
        mask_name = os.path.join(self.data_root, self.mask_dir, img_id + self.mask_suffix)
        imgA = self._open_rgb(imgA_name)
        imgB = self._open_rgb(imgB_name)
        mask_rgb = self._open_rgb(mask_name)
        mask = self.rgb2label(mask_rgb)
        return imgA, imgB, mask, img_id

    # NOTE(review): stale, disabled code. It unpacks load_img_and_mask() as
    # (img, mask), but that method now returns (imgA, imgB, mask, img_id), and
    # `random` is not imported explicitly in this file (it may arrive via
    # `from .transform import *` -- confirm). Update or remove before re-enabling.
    # def load_mosaic_img_and_mask(self, index):
    #     indexes = [index] + [random.randint(0, len(self.file_paths) - 1) for _ in range(3)]
    #     img_a, mask_a = self.load_img_and_mask(indexes[0])
    #     img_b, mask_b = self.load_img_and_mask(indexes[1])
    #     img_c, mask_c = self.load_img_and_mask(indexes[2])
    #     img_d, mask_d = self.load_img_and_mask(indexes[3])

    #     img_a, mask_a = np.array(img_a), np.array(mask_a)
    #     img_b, mask_b = np.array(img_b), np.array(mask_b)
    #     img_c, mask_c = np.array(img_c), np.array(mask_c)
    #     img_d, mask_d = np.array(img_d), np.array(mask_d)

    #     h = self.img_size[0]
    #     w = self.img_size[1]

    #     start_x = w // 4
    #     strat_y = h // 4
    #     # The coordinates of the splice center
    #     offset_x = random.randint(start_x, (w - start_x))
    #     offset_y = random.randint(strat_y, (h - strat_y))

    #     crop_size_a = (offset_x, offset_y)
    #     crop_size_b = (w - offset_x, offset_y)
    #     crop_size_c = (offset_x, h - offset_y)
    #     crop_size_d = (w - offset_x, h - offset_y)

    #     random_crop_a = albu.RandomCrop(width=crop_size_a[0], height=crop_size_a[1])
    #     random_crop_b = albu.RandomCrop(width=crop_size_b[0], height=crop_size_b[1])
    #     random_crop_c = albu.RandomCrop(width=crop_size_c[0], height=crop_size_c[1])
    #     random_crop_d = albu.RandomCrop(width=crop_size_d[0], height=crop_size_d[1])

    #     croped_a = random_crop_a(image=img_a.copy(), mask=mask_a.copy())
    #     croped_b = random_crop_b(image=img_b.copy(), mask=mask_b.copy())
    #     croped_c = random_crop_c(image=img_c.copy(), mask=mask_c.copy())
    #     croped_d = random_crop_d(image=img_d.copy(), mask=mask_d.copy())

    #     img_crop_a, mask_crop_a = croped_a['image'], croped_a['mask']
    #     img_crop_b, mask_crop_b = croped_b['image'], croped_b['mask']
    #     img_crop_c, mask_crop_c = croped_c['image'], croped_c['mask']
    #     img_crop_d, mask_crop_d = croped_d['image'], croped_d['mask']

    #     top = np.concatenate((img_crop_a, img_crop_b), axis=1)
    #     bottom = np.concatenate((img_crop_c, img_crop_d), axis=1)
    #     img = np.concatenate((top, bottom), axis=0)

    #     top_mask = np.concatenate((mask_crop_a, mask_crop_b), axis=1)
    #     bottom_mask = np.concatenate((mask_crop_c, mask_crop_d), axis=1)
    #     mask = np.concatenate((top_mask, bottom_mask), axis=0)
    #     mask = np.ascontiguousarray(mask)
    #     img = np.ascontiguousarray(img)
    #     img = Image.fromarray(img)
    #     mask = Image.fromarray(mask)
    #     # print(img.shape)

    #     return img, mask