File size: 1,668 Bytes
168ec29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch.utils import data
import os
from PIL import Image, ImageFile


ImageFile.LOAD_TRUNCATED_IMAGES = True


class EvalDataset(data.Dataset):
    """Dataset pairing prediction maps with ground-truth masks for evaluation.

    Scans ``pred_root`` and ``label_root`` (one sub-directory per sequence,
    one image file per frame) and keeps only the (sub-dir, file-name)
    combinations present under BOTH roots. Ground-truth images are opened
    once up front; predictions are opened lazily per item and resized
    (bilinear) to the ground truth's size when the two differ.
    """

    def __init__(self, pred_root, label_root, return_predpath=False, return_gtpath=False):
        """
        Args:
            pred_root (str): directory of predicted maps, one sub-dir per sequence.
            label_root (str): directory of ground-truth masks, same layout.
            return_predpath (bool): if True, ``__getitem__`` also returns the
                prediction file path.
            return_gtpath (bool): if True, ``__getitem__`` also returns the
                ground-truth file path.
        """
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        # Sets give O(1) membership tests; the original scanned lists,
        # making discovery O(n*m) per directory. Iteration still follows
        # the pred-side os.listdir() order, so item order is unchanged.
        label_dirs = set(os.listdir(label_root))

        dir_name_list = []
        for idir in os.listdir(pred_root):
            if idir not in label_dirs:
                continue
            label_names = set(os.listdir(os.path.join(label_root, idir)))
            for iname in os.listdir(os.path.join(pred_root, idir)):
                if iname in label_names:
                    dir_name_list.append(os.path.join(idir, iname))

        self.image_path = [os.path.join(pred_root, name) for name in dir_name_list]
        self.label_path = [os.path.join(label_root, name) for name in dir_name_list]

        # Preload all ground truths once; predictions are loaded lazily in
        # __getitem__. NOTE(review): this keeps every GT image (and its
        # underlying file handle) resident for the dataset's lifetime —
        # fine for evaluation-sized sets, but worth confirming for very
        # large benchmarks.
        self.labels = [Image.open(p).convert('L') for p in self.label_path]

    def __getitem__(self, item):
        """Return ``[pred, gt]`` (plus optional paths), both mode-'L' PIL images.

        The prediction is resized to the ground truth's size when they
        differ, so downstream element-wise metrics line up.
        """
        predpath = self.image_path[item]
        gtpath = self.label_path[item]
        pred = Image.open(predpath).convert('L')
        gt = self.labels[item]
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        """Number of matched prediction / ground-truth pairs."""
        return len(self.image_path)