Matharrr committed on
Commit
c1ed8bf
·
1 Parent(s): 1d92877

first commit

Browse files
data/__init__.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from data.base_dataset import BaseDataset
16
+
17
+
18
def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.

    Parameters:
        dataset_name (str) -- name of the dataset module, e.g. 'unaligned'

    Returns:
        the dataset class (a subclass of BaseDataset)

    Raises:
        NotImplementedError -- if no matching BaseDataset subclass exists in the module.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # Guard with isinstance(cls, type): issubclass() raises TypeError when a
        # matching module attribute is not a class (e.g. a module-level variable).
        if name.lower() == target_dataset_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
39
+
40
+
41
def get_option_setter(dataset_name):
    """Look up the dataset class named by `dataset_name` and return its
    static <modify_commandline_options> method."""
    dataset_cls = find_dataset_using_name(dataset_name)
    return dataset_cls.modify_commandline_options
45
+
46
+
47
def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt)
    return loader.load_data()
60
+
61
+
62
class CustomDatasetDataLoader:
    """Wrapper class of Dataset class that performs multi-threaded data loading."""

    def __init__(self, opt):
        """Initialize this class.

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be
                a subclass of BaseOptions.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            # Drop incomplete batches only while training (was the redundant
            # `True if opt.isTrain else False`).
            drop_last=bool(opt.isTrain),
        )

    def set_epoch(self, epoch):
        """Propagate the running epoch number to the dataset (read e.g. by
        UnalignedDataset's finetuning check)."""
        self.dataset.current_epoch = epoch

    def load_data(self):
        """Return this wrapper itself; iterating it yields batches."""
        return self

    def __len__(self):
        """Return the number of data in the dataset, capped at opt.max_dataset_size."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping once max_dataset_size samples have been served."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
data/base_dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) for all datasets in this package.

    Subclasses must implement:
    -- <__init__>: initialize the class; call BaseDataset.__init__(self, opt) first.
    -- <__len__>: return the size of the dataset.
    -- <__getitem__>: fetch a single data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Store the experiment options on the instance.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt              # all experiment flags
        self.root = opt.dataroot    # dataset root directory
        self.current_epoch = 0      # updated externally via the loader's set_epoch()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options and rewrite defaults of existing ones.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase; lets subclasses
                add phase-specific options.

        Returns:
            the modified parser (unchanged in this base implementation).
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of named data entries (the data itself plus metadata).
        """
        pass
62
+
63
+
64
def get_params(opt, size):
    """Draw random crop position and flip flag for one image.

    Parameters:
        opt -- option object; reads opt.preprocess, opt.load_size, opt.crop_size
        size (tuple) -- (width, height) of the original image

    Returns:
        dict with 'crop_pos' (x, y) and 'flip' (bool)
    """
    w, h = size
    if opt.preprocess == 'resize_and_crop':
        new_w = new_h = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    else:
        new_w, new_h = w, h

    # Crop origin anywhere inside the (post-resize) image, clamped at zero.
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    return {'crop_pos': (x, y), 'flip': random.random() > 0.5}
80
+
81
+
82
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build a torchvision transform pipeline from the preprocessing options.

    Parameters:
        opt -- option object; reads opt.preprocess (a flag string such as
               'resize_and_crop', 'zoom', 'patch', ...), opt.load_size,
               opt.crop_size, opt.dataroot and opt.no_flip
        params (dict) -- optional precomputed augmentation parameters
                         ('size', 'scale_factor', 'crop_pos', 'patch_index', 'flip');
                         when None, random transforms are used instead
        grayscale (bool) -- convert to 1-channel and normalize accordingly
        method -- PIL resampling filter used by the resize operations
        convert (bool) -- append ToTensor + Normalize at the end

    Returns:
        a transforms.Compose applying the selected operations in order.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'fixsize' in opt.preprocess:
        # Requires params["size"]; resizes to an exact caller-given size.
        transform_list.append(transforms.Resize(params["size"], method))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            # Special-case this dataset: halve the height (2:1 frames).
            osize[0] = opt.load_size // 2
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            # Reuse the caller-provided zoom factor so a minibatch shares one scale.
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        # Requires params['patch_index']; picks one grid tile of the image.
        transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # if opt.preprocess == 'none':
    # NOTE(review): the guard above is commented out, so every image is snapped
    # to a multiple of 4 unconditionally — confirm this is intended.
    transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        # Map [0, 1] tensors to [-1, 1] per channel.
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
132
+
133
+
134
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize img so both sides are multiples of `base` (rounded); no-op if they already are."""
    width, height = img.size
    new_h = int(round(height / base) * base)
    new_w = int(round(width / base) * base)
    if (new_w, new_h) == (width, height):
        return img

    return img.resize((new_w, new_h), method)
142
+
143
+
144
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    """Rescale img by a per-axis zoom factor (random in [0.8, 1.0) when factor is
    None), never letting either side drop below crop_width."""
    if factor is None:
        zoom = np.random.uniform(0.8, 1.0, size=[2])
    else:
        zoom = (factor[0], factor[1])
    width, height = img.size
    new_w = max(crop_width, width * zoom[0])
    new_h = max(crop_width, height * zoom[1])
    return img.resize((int(round(new_w)), int(round(new_h))), method)
154
+
155
+
156
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    """Scale img up so its shorter side equals target_width; larger images pass through."""
    width, height = img.size
    short_side = min(width, height)
    if short_side >= target_width:
        return img
    ratio = target_width / short_side
    return img.resize((round(width * ratio), round(height * ratio)), method)
164
+
165
+
166
def __trim(img, trim_width):
    """Randomly crop img to at most trim_width per side; dimensions already
    within the limit are kept whole."""
    width, height = img.size
    if width > trim_width:
        left = np.random.randint(width - trim_width)
        right = left + trim_width
    else:
        left, right = 0, width
    if height > trim_width:
        top = np.random.randint(height - trim_width)
        bottom = top + trim_width
    else:
        top, bottom = 0, height
    return img.crop((left, top, right, bottom))
181
+
182
+
183
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    """Resize img to width target_width, preserving aspect ratio but never
    letting the height drop below crop_width."""
    width, height = img.size
    if width == target_width and height >= crop_width:
        return img
    scaled_height = int(max(target_width * height / width, crop_width))
    return img.resize((target_width, scaled_height), method)
190
+
191
+
192
def __crop(img, pos, size):
    """Crop a size x size patch at pos; return img unchanged when neither
    dimension exceeds size."""
    width, height = img.size
    left, top = pos
    if width > size or height > size:
        return img.crop((left, top, left + size, top + size))
    return img
199
+
200
+
201
def __patch(img, index, size):
    """Divide img into a grid of size x size tiles (shifted by a random global
    offset inside the leftover margin) and return the tile selected by index."""
    width, height = img.size
    tiles_x, tiles_y = width // size, height // size
    margin_x = width - tiles_x * size
    margin_y = height - tiles_y * size
    offset_x = np.random.randint(int(margin_x) + 1)
    offset_y = np.random.randint(int(margin_y) + 1)

    index = index % (tiles_x * tiles_y)
    col = index // tiles_y
    row = index % tiles_y
    left = offset_x + col * size
    top = offset_y + row * size
    return img.crop((left, top, left + size, top + size))
215
+
216
+
217
def __flip(img, flip):
    """Mirror img horizontally when flip is truthy; otherwise pass it through."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
221
+
222
+
223
def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size(only print once)"""
    # A function attribute acts as a module-lifetime "already printed" latch.
    if hasattr(__print_size_warning, 'has_printed'):
        return
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % (ow, oh, w, h))
    __print_size_warning.has_printed = True
data/image_folder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
def is_image_file(filename):
    """Return True when `filename` ends with a recognized image extension.

    str.endswith accepts a tuple of suffixes, so one C-level call replaces the
    per-extension generator scan; the case-sensitive matching is unchanged.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
22
+
23
+
24
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image file paths under `dir` (symlinks are followed).

    Parameters:
        dir (str) -- directory (or symlink to one) to scan
        max_dataset_size -- cap on the number of returned paths

    Returns:
        list of image paths in sorted-walk order, truncated to max_dataset_size.
    """
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    images = [os.path.join(root, fname)
              for root, _, fnames in sorted(os.walk(dir, followlinks=True))
              for fname in fnames
              if is_image_file(fname)]
    return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
def default_loader(path):
    """Open the file at `path` with PIL and return it as an RGB image."""
    image = Image.open(path)
    return image.convert('RGB')
38
+
39
+
40
class ImageFolder(data.Dataset):
    """Dataset over every image found under a root directory, subdirectories included."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Scan `root` for images and store the loading configuration.

        Parameters:
            root (str) -- directory to scan
            transform -- optional callable applied to each loaded image
            return_paths (bool) -- when True, __getitem__ also returns the path
            loader -- callable that opens a path and returns an image

        Raises:
            RuntimeError -- when no image files are found under `root`.
        """
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load (and optionally transform) the image at `index`."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Return the number of images found."""
        return len(self.imgs)
data/single_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.base_dataset import BaseDataset, get_transform
2
+ from data.image_folder import make_dataset
3
+ from PIL import Image
4
+
5
+
6
class SingleDataset(BaseDataset):
    """Load a flat set of images given by the path --dataroot /path/to/data.

    Useful for generating CycleGAN results for only one side, with the model
    option '-model test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        # For direction 'BtoA' the generator consumes output_nc channels instead.
        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))

    def __getitem__(self, index):
        """Return one transformed image together with its path.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary with:
            A (tensor) -- an image in one domain
            A_paths (str) -- the path of the image
        """
        path = self.A_paths[index]
        image = Image.open(path).convert('RGB')
        return {'A': self.transform(image), 'A_paths': path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
data/singleimage_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os.path
3
+ from data.base_dataset import BaseDataset, get_transform
4
+ from data.image_folder import make_dataset
5
+ from PIL import Image
6
+ import random
7
+ import util.util as util
8
+
9
+
10
class SingleImageDataset(BaseDataset):
    """Single-image translation dataset: exactly one image per domain.

    Expects one image in '/path/to/data/trainA' and one in '/path/to/data/trainB'
    (the assert below enforces this). Augmentation (random zoom + patch crops)
    turns that single pair into many training samples; see __len__.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.dir_A = os.path.join(opt.dataroot, 'trainA')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, 'trainB')  # create a path '/path/to/data/trainB'

        # NOTE(review): if either directory is missing, A_paths/B_paths are never
        # set and the assert below raises AttributeError instead of a clear error.
        if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
            self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
            self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
            self.A_size = len(self.A_paths)  # get the size of dataset A
            self.B_size = len(self.B_paths)  # get the size of dataset B

        assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
            "SingleImageDataset class should be used with one image in each domain"
        A_img = Image.open(self.A_paths[0]).convert('RGB')
        B_img = Image.open(self.B_paths[0]).convert('RGB')
        print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))

        # Cache the decoded images; only transforms run per __getitem__.
        self.A_img = A_img
        self.B_img = B_img

        # In single-image translation, we augment the data loader by applying
        # random scaling. Still, we design the data loader such that the
        # amount of scaling is the same within a minibatch. To do this,
        # we precompute the random scaling values, and repeat them by |batch_size|.
        A_zoom = 1 / self.opt.random_scale_max
        zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])

        B_zoom = 1 / self.opt.random_scale_max
        zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])

        # While the crop locations are randomized, the negative samples should
        # not come from the same location. To do this, we precompute the
        # crop locations with no repetition.
        self.patch_indices_A = list(range(len(self)))
        random.shuffle(self.patch_indices_A)
        self.patch_indices_B = list(range(len(self)))
        random.shuffle(self.patch_indices_B)

    def __getitem__(self, index):
        """Return one augmented crop pair from the cached single images.

        Parameters:
            index (int) -- selects the precomputed zoom level and patch index

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[0]
        B_path = self.B_paths[0]
        A_img = self.A_img
        B_img = self.B_img

        # apply image transformation; during training the precomputed zoom/patch
        # parameters are used, at test time only the deterministic pipeline runs.
        if self.opt.phase == "train":
            param = {'scale_factor': self.zoom_levels_A[index],
                     'patch_index': self.patch_indices_A[index],
                     'flip': random.random() > 0.5}

            transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
            A = transform_A(A_img)

            param = {'scale_factor': self.zoom_levels_B[index],
                     'patch_index': self.patch_indices_B[index],
                     'flip': random.random() > 0.5}
            transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
            B = transform_B(B_img)
        else:
            transform = get_transform(self.opt, method=Image.BILINEAR)
            A = transform(A_img)
            B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """ Let's pretend the single image contains 100,000 crops for convenience.
        """
        return 100000
data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.py
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase; use it to add
                phase-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        # Dataset-specific default values override the global defaults.
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        Typical work done here:
        - save the options (already handled by BaseDataset.__init__)
        - collect image paths and metadata for the dataset
        - define the image transformation.
        """
        BaseDataset.__init__(self, opt)  # stores opt and opt.dataroot
        # Collect your image paths here, e.g.
        # sorted(make_dataset(self.root, opt.max_dataset_size)).
        self.image_paths = []
        # Default transform from base_dataset; swap in a custom one if needed.
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of named data entries.

        Typical steps:
        1. pick a path: path = self.image_paths[index]
        2. load from disk: image = Image.open(path).convert('RGB')
        3. convert to a tensor, e.g. data = self.transform(image)
        4. return the data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
data/unaligned_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from data.base_dataset import BaseDataset, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image
5
+ import random
6
+ import util.util as util
7
+
8
+
9
class UnalignedDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.

    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        # Fall back to valA/valB when testing and testA does not exist.
        if opt.phase == "test" and not os.path.exists(self.dir_A) \
           and os.path.exists(os.path.join(opt.dataroot, "valA")):
            self.dir_A = os.path.join(opt.dataroot, "valA")
            self.dir_B = os.path.join(opt.dataroot, "valB")

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
        if self.opt.serial_batches:  # make sure index is within the range
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        # Apply image transformation
        # For CUT/FastCUT mode, if in finetuning phase (learning rate is decaying),
        # do not perform resize-crop data augmentation of CycleGAN.
        # current_epoch is pushed in externally via the data loader's set_epoch().
        is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
        modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
        transform = get_transform(modified_opt)
        A = transform(A_img)
        B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different number of images,
        we take a maximum of the two sizes; __getitem__ wraps indices modulo
        each domain's own size.
        """
        return max(self.A_size, self.B_size)