File size: 4,156 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils
from torchvision import transforms


class ImageDataset(Dataset):
    """Flat dataset of RGB images discovered under ``root``.

    Args:
        root: directory handed to ``data_utils.make_dataset``; the returned
            paths are sorted so indexing is deterministic.
        transform: optional callable applied to each loaded image.
    """

    def __init__(self, root, transform=None):
        self.paths = sorted(data_utils.make_dataset(root))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, then apply ``transform`` if one is set."""
        path = self.paths[index]
        # The context manager closes the file handle deterministically;
        # convert() returns a new image object, so it remains usable afterwards.
        with Image.open(path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image


class CelebaAttributeDataset(Dataset):
    """RGB images from ``images_root`` filtered by a single binary attribute.

    The annotation file at ``attributes_root`` is read with this layout
    (presumably the CelebAMask-HQ attribute file -- confirm against the data):

    * line 0 is not read;
    * line 1 lists the attribute names, separated by whitespace;
    * line ``k + 2`` annotates the image whose file stem is the integer ``k``:
      a filename token followed by one ``"1"`` / ``"-1"`` token per attribute.

    Args:
        images_root: directory handed to ``data_utils.make_dataset``.
        attr: attribute name to filter on (surrounding whitespace ignored).
        transform: optional callable applied to each loaded image.
        attributes_root: path to the annotation file.
        use_attr: keep images annotated ``"1"`` when True, ``"-1"`` when False.

    Raises:
        ValueError: if ``attr`` does not appear among the attribute names.
    """

    def __init__(self, images_root, attr, transform=None, attributes_root="", use_attr=True):
        self.transform = transform
        with open(attributes_root, "r") as f:
            lines = f.readlines()

        # split() with no argument tolerates runs of spaces and trailing
        # newlines, so column indices do not depend on the file's exact spacing.
        attr_names = lines[1].split()
        target = attr.strip()
        # Explicit raise rather than assert: asserts are stripped under
        # ``python -O``, which would silently select attribute index -1.
        if target not in attr_names:
            raise ValueError(f"Can not find attribute {attr}")
        attr_num = attr_names.index(target)

        wanted = "1" if use_attr else "-1"
        filtered_paths = []
        for path in data_utils.make_dataset(images_root):
            # basename/splitext are OS-portable and drop any extension.
            pic_num = int(os.path.splitext(os.path.basename(path))[0])
            # Drop the leading filename token so values align with attr_names.
            pic_attrs = lines[pic_num + 2].split()[1:]
            if pic_attrs[attr_num] == wanted:
                filtered_paths.append(path)
        self.paths = sorted(filtered_paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, then apply ``transform`` if one is set."""
        from_path = self.paths[index]
        # Close the file handle deterministically; convert() returns a new image.
        with Image.open(from_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image


class FIDDataset(Dataset):
    """Dataset over an in-memory sequence of image objects.

    Items must expose ``convert(mode)`` (presumably PIL images). Each item is
    converted to RGB on access and then passed through ``transforms`` unless
    that is None.
    """

    def __init__(self, files, transforms=None):
        self.files, self.transforms = files, transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        rgb = self.files[i].convert("RGB")
        if self.transforms is None:
            return rgb
        return self.transforms(rgb)


class MetricsPathsDataset(Dataset):
    """Pairs of (image, ground-truth image) read from two directories.

    Every ``.jpg``/``.png`` file in ``root_path`` that is not listed in
    ``ignore`` is paired with the ``.jpg`` file of the same stem in ``gt_dir``.
    Files are visited in sorted name order so indexing is deterministic.

    Args:
        root_path: directory of images to evaluate.
        gt_dir: directory of ground-truth images. Although it defaults to
            None, it is required whenever ``root_path`` contains images
            (``os.path.join`` raises TypeError on None).
        transform: optional callable applied to both images of a pair.
        transform_train: stored on the instance but not used by this class.
        return_path: when True, ``__getitem__`` also returns the file name.
        ignore: file names in ``root_path`` to skip.
    """

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None, return_path=False, ignore=None):
        if ignore is None:
            ignore = ()
        self.pairs = []
        self.paths = []
        self.names = []

        # os.listdir order is arbitrary; sort for reproducible iteration.
        for f in sorted(os.listdir(root_path)):
            if f in ignore or not f.endswith((".jpg", ".png")):
                continue
            image_path = os.path.join(root_path, f)
            # Ground truth is always the .jpg of the same stem; only the file
            # name is rewritten, never the gt_dir part of the path.
            gt_path = os.path.join(gt_dir, os.path.splitext(f)[0] + ".jpg")
            self.pairs.append([image_path, gt_path, None])
            self.paths.append(image_path)
            # Appended only for images so names[i] always matches pairs[i].
            self.names.append(f)
        self.transform = transform
        self.transform_train = transform_train
        self.return_path = return_path

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        """Return ``(from_im, to_im)``, plus the file name when ``return_path``."""
        from_path, to_path, _ = self.pairs[index]
        # Context managers close the file handles; convert() returns new images.
        with Image.open(from_path) as img:
            from_im = img.convert("RGB")
        with Image.open(to_path) as img:
            to_im = img.convert("RGB")

        if self.transform:
            # Keep the original call order (target first) in case the
            # transform is stateful or random.
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        if self.return_path:
            return from_im, to_im, self.names[index]
        return from_im, to_im


class MetricsDataDataset(Dataset):
    """Index-aligned pairs of in-memory (target, fake) samples.

    ``paths`` and ``transform_train`` are stored on the instance but are not
    used by this class.
    """

    def __init__(
        self, paths, target_data, fake_data, transform=None, transform_train=None
    ):
        self.paths = paths
        self.target_data = target_data
        self.fake_data = fake_data
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        # Length follows fake_data; target_data is expected to be at least as long.
        return len(self.fake_data)

    def __getitem__(self, index):
        target, fake = self.target_data[index], self.fake_data[index]
        if not self.transform:
            return target, fake
        # The fake sample is transformed before the target (order matters if
        # the transform is stateful or random).
        fake = self.transform(fake)
        target = self.transform(target)
        return target, fake