Spaces:
Runtime error
Runtime error
first commit
Browse files- data/__init__.py +98 -0
- data/base_dataset.py +230 -0
- data/image_folder.py +66 -0
- data/single_dataset.py +40 -0
- data/singleimage_dataset.py +108 -0
- data/template_dataset.py +75 -0
- data/unaligned_dataset.py +78 -0
data/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This package includes all the modules related to data loading and preprocessing
|
| 2 |
+
|
| 3 |
+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
| 4 |
+
You need to implement four functions:
|
| 5 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
| 6 |
+
-- <__len__>: return the size of dataset.
|
| 7 |
+
-- <__getitem__>: get a data point from data loader.
|
| 8 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
| 9 |
+
|
| 10 |
+
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
| 11 |
+
See our template dataset class 'template_dataset.py' for more details.
|
| 12 |
+
"""
|
| 13 |
+
import importlib
|
| 14 |
+
import torch.utils.data
|
| 15 |
+
from data.base_dataset import BaseDataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.

    Parameters:
        dataset_name (str) -- the dataset mode, e.g. 'unaligned' or 'single_image'

    Returns:
        the matching dataset class (a subclass of BaseDataset)

    Raises:
        NotImplementedError -- if the module defines no matching subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    # e.g. 'single_image' -> 'singleimagedataset'; matched case-insensitively below.
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) guards issubclass(): a module namespace also holds
        # functions/modules/strings, and issubclass() raises TypeError on non-classes.
        if name.lower() == target_dataset_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_option_setter(dataset_name):
    """Look up the dataset class registered under *dataset_name* and return
    its static <modify_commandline_options> method."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt)
    return loader.load_data()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # Resolve the concrete dataset class (e.g. UnalignedDataset) from the
        # --dataset_mode flag, then instantiate it.
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,  # --serial_batches disables shuffling
            num_workers=int(opt.num_threads),
            # drop the last, incomplete batch only when training
            drop_last=True if opt.isTrain else False,
        )

    def set_epoch(self, epoch):
        # Expose the current epoch to the dataset; BaseDataset initializes
        # current_epoch = 0 and subclasses may read it (e.g. to change
        # augmentation once fine-tuning starts; see UnalignedDataset).
        self.dataset.current_epoch = epoch

    def load_data(self):
        # The wrapper itself acts as the "dataset": it supports len() and iteration.
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        # Reported size is capped at --max_dataset_size.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            # Stop once --max_dataset_size images have been yielded.
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
|
data/base_dataset.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
| 2 |
+
|
| 3 |
+
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
| 4 |
+
"""
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot
        # Updated externally via CustomDatasetDataLoader.set_epoch(); lets a
        # subclass vary its behavior as training progresses.
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_params(opt, size):
    """Draw shared random augmentation parameters for one image.

    Parameters:
        opt  -- option object; opt.preprocess, opt.load_size and opt.crop_size are read
        size -- (width, height) of the source image

    Returns:
        dict with 'crop_pos' (x, y) inside the post-resize image and a boolean 'flip'.
    """
    width, height = size
    resized_w, resized_h = width, height
    if opt.preprocess == 'resize_and_crop':
        resized_h = resized_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        resized_w = opt.load_size
        resized_h = opt.load_size * height // width

    # Pick a crop origin that keeps a crop_size window inside the resized image.
    crop_x = random.randint(0, np.maximum(0, resized_w - opt.crop_size))
    crop_y = random.randint(0, np.maximum(0, resized_h - opt.crop_size))

    do_flip = random.random() > 0.5

    return {'crop_pos': (crop_x, crop_y), 'flip': do_flip}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Compose the preprocessing/augmentation pipeline selected by opt.preprocess.

    Parameters:
        opt       -- option object; opt.preprocess is a string containing tokens such as
                     'fixsize', 'resize', 'scale_width', 'scale_shortside', 'zoom',
                     'crop', 'patch', 'trim'
        params    -- optional dict of precomputed augmentation parameters
                     ('size', 'scale_factor', 'crop_pos', 'patch_index', 'flip');
                     when None (or a key is absent) the random torchvision transform
                     is used instead of the deterministic one
        grayscale -- prepend Grayscale(1) and normalize with single-channel stats
        method    -- PIL resampling filter used by all resize operations
        convert   -- append ToTensor + Normalize (maps pixel values to [-1, 1])

    Returns:
        a transforms.Compose object.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'fixsize' in opt.preprocess:
        # NOTE(review): 'fixsize' requires params with a "size" entry; params=None raises here.
        transform_list.append(transforms.Resize(params["size"], method))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            # Special-case for the gta2cityscapes dataset: halve the first Resize
            # dimension. NOTE(review): transforms.Resize takes (h, w) -- confirm
            # the intended aspect handling.
            osize[0] = opt.load_size // 2
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            # Deterministic zoom: reuse the precomputed per-index scale factor.
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        # NOTE(review): 'patch' requires params with a 'patch_index' entry.
        transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # if opt.preprocess == 'none':
    # NOTE(review): the guard above is commented out, so EVERY pipeline rounds the
    # image sides to a multiple of 4 -- presumably required by the network's
    # down/upsampling layers; confirm before re-enabling the guard.
    transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        # Normalize with mean=std=0.5, i.e. map [0, 1] tensors to [-1, 1].
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides are (rounded to) multiples of *base*;
    returns the image unchanged when it is already aligned."""
    width, height = img.size
    new_h = int(round(height / base) * base)
    new_w = int(round(width / base) * base)
    if (new_w, new_h) == (width, height):
        return img

    return img.resize((new_w, new_h), method)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    """Rescale *img* by a random (or caller-supplied) per-axis factor in [0.8, 1.0],
    never letting either side shrink below *crop_width*."""
    if factor is None:
        scale_w, scale_h = np.random.uniform(0.8, 1.0, size=[2])
    else:
        scale_w, scale_h = factor[0], factor[1]
    in_w, in_h = img.size
    out_w = max(crop_width, in_w * scale_w)
    out_h = max(crop_width, in_h * scale_h)
    return img.resize((int(round(out_w)), int(round(out_h))), method)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    """Upscale *img* (keeping aspect ratio) until its shorter side reaches
    *target_width*; images already large enough are returned as-is."""
    width, height = img.size
    short_side = min(width, height)
    if short_side >= target_width:
        return img
    ratio = target_width / short_side
    return img.resize((round(width * ratio), round(height * ratio)), method)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def __trim(img, trim_width):
    """Randomly crop *img* to at most trim_width x trim_width; any axis that is
    already no larger than trim_width is kept whole."""
    width, height = img.size
    if width > trim_width:
        left = np.random.randint(width - trim_width)
        right = left + trim_width
    else:
        left, right = 0, width
    if height > trim_width:
        top = np.random.randint(height - trim_width)
        bottom = top + trim_width
    else:
        top, bottom = 0, height
    return img.crop((left, top, right, bottom))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    """Resize *img* to width *target_width* (keeping aspect ratio), clamping the
    height up to *crop_width* so a later crop always fits."""
    width, height = img.size
    if width == target_width and height >= crop_width:
        return img
    scaled_height = int(max(target_width * height / width, crop_width))
    return img.resize((target_width, scaled_height), method)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def __crop(img, pos, size):
    """Crop a size x size patch whose top-left corner is *pos*; images no larger
    than the patch in both dimensions are returned unchanged."""
    width, height = img.size
    left, top = pos
    if width > size or height > size:
        return img.crop((left, top, left + size, top + size))
    return img
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def __patch(img, index, size):
    """Divide *img* into a grid of size x size patches (with one random global
    offset inside the leftover margin) and return the patch chosen by *index*,
    wrapped modulo the number of grid cells."""
    width, height = img.size
    cols, rows = width // size, height // size
    margin_x = width - cols * size
    margin_y = height - rows * size
    offset_x = np.random.randint(int(margin_x) + 1)
    offset_y = np.random.randint(int(margin_y) + 1)

    index = index % (cols * rows)
    col = index // rows
    row = index % rows
    left = offset_x + col * size
    top = offset_y + row * size
    return img.crop((left, top, left + size, top + size))
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def __flip(img, flip):
    """Mirror *img* left-right when *flip* is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size(only print once)"""
    # A function attribute acts as a one-shot latch across calls.
    if hasattr(__print_size_warning, 'has_printed'):
        return
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % (ow, oh, w, h))
    __print_size_warning.has_printed = True
|
data/image_folder.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A modified image folder class
|
| 2 |
+
|
| 3 |
+
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
| 4 |
+
so that this class can load images from both current directory and its subdirectories.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import os
|
| 11 |
+
import os.path
|
| 12 |
+
|
| 13 |
+
# Recognized image file suffixes (lower- and upper-case variants).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True when *filename* ends with one of the recognized image suffixes."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect paths of all image files under *dir*.

    Parameters:
        dir (str)                        -- directory to walk (symlinks followed)
        max_dataset_size (int or float)  -- keep at most this many paths
                                            (default: unlimited)

    Returns:
        a list of image file paths (os.walk order), truncated to max_dataset_size.
    """
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    # Cast the bound through int(): list slicing rejects floats, so a finite
    # float max_dataset_size would otherwise raise TypeError. min() with
    # len(images) guarantees the value is finite before the cast.
    return images[:int(min(max_dataset_size, len(images)))]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def default_loader(path):
    """Open the image at *path* with PIL and force-convert it to 3-channel RGB."""
    return Image.open(path).convert('RGB')
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ImageFolder(data.Dataset):
    """Dataset over every image file found (recursively) under *root*.

    Unlike the official torchvision ImageFolder, no class-subdirectory layout
    is required: all images under root (and its subdirectories) become items.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Collect image paths under *root*.

        Parameters:
            root (str)          -- directory to scan recursively
            transform           -- optional callable applied to each loaded image
            return_paths (bool) -- when True, __getitem__ returns (img, path)
            loader              -- callable mapping a path to an image (default: PIL RGB)

        Raises:
            RuntimeError -- if no image file is found under root.
        """
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load (and optionally transform) the image at position *index*."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        """Return the number of images found under root."""
        return len(self.imgs)
|
data/single_dataset.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data.base_dataset import BaseDataset, get_transform
|
| 2 |
+
from data.image_folder import make_dataset
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    It can be used for generating CycleGAN results only for one side with the model option '-model test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        # Images are collected directly from --dataroot (recursively), sorted for determinism.
        self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        # With --direction BtoA the network's input channel count is the one
        # configured as output_nc.
        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns a dictionary that contains A and A_paths
            A(tensor) - - an image in one domain
            A_paths(str) - - the path of the image
        """
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A = self.transform(A_img)
        return {'A': A, 'A_paths': A_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
|
data/singleimage_dataset.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os.path
|
| 3 |
+
from data.base_dataset import BaseDataset, get_transform
|
| 4 |
+
from data.image_folder import make_dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import random
|
| 7 |
+
import util.util as util
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SingleImageDataset(BaseDataset):
    """Dataset for single-image translation: exactly one training image per domain.

    It expects exactly one image in '/path/to/data/trainA' and one in
    '/path/to/data/trainB' (enforced by an assert). Each item is a randomly
    zoomed/cropped/flipped view of those two images; __len__ reports a large
    virtual size so one "epoch" iterates over many random crops.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.dir_A = os.path.join(opt.dataroot, 'trainA')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, 'trainB')  # create a path '/path/to/data/trainB'

        if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
            self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
            self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
        # NOTE(review): if either directory is missing, A_paths/B_paths are never
        # assigned and the next lines raise AttributeError -- confirm whether a
        # clearer error message is intended here.
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B

        assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
            "SingleImageDataset class should be used with one image in each domain"
        A_img = Image.open(self.A_paths[0]).convert('RGB')
        B_img = Image.open(self.B_paths[0]).convert('RGB')
        print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))

        # Keep both images in memory; __getitem__ only re-augments them.
        self.A_img = A_img
        self.B_img = B_img

        # In single-image translation, we augment the data loader by applying
        # random scaling. Still, we design the data loader such that the
        # amount of scaling is the same within a minibatch. To do this,
        # we precompute the random scaling values, and repeat them by |batch_size|.
        A_zoom = 1 / self.opt.random_scale_max
        zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])

        B_zoom = 1 / self.opt.random_scale_max
        zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])

        # While the crop locations are randomized, the negative samples should
        # not come from the same location. To do this, we precompute the
        # crop locations with no repetition.
        self.patch_indices_A = list(range(len(self)))
        random.shuffle(self.patch_indices_A)
        self.patch_indices_B = list(range(len(self)))
        random.shuffle(self.patch_indices_B)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[0]
        B_path = self.B_paths[0]
        A_img = self.A_img
        B_img = self.B_img

        # apply image transformation
        if self.opt.phase == "train":
            # Use the precomputed per-index zoom/patch parameters so every item
            # in a minibatch shares the same scale and no two indices share a crop.
            param = {'scale_factor': self.zoom_levels_A[index],
                     'patch_index': self.patch_indices_A[index],
                     'flip': random.random() > 0.5}

            transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
            A = transform_A(A_img)

            param = {'scale_factor': self.zoom_levels_B[index],
                     'patch_index': self.patch_indices_B[index],
                     'flip': random.random() > 0.5}
            transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
            B = transform_B(B_img)
        else:
            # Evaluation: a single deterministic-parameter-free pipeline for both sides.
            transform = get_transform(self.opt, method=Image.BILINEAR)
            A = transform(A_img)
            B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """ Let's pretend the single image contains 100,000 crops for convenience.
        """
        return 100000
|
data/template_dataset.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset class template
|
| 2 |
+
|
| 3 |
+
This module provides a template for users to implement custom datasets.
|
| 4 |
+
You can specify '--dataset_mode template' to use this dataset.
|
| 5 |
+
The class name should be consistent with both the filename and its dataset_mode option.
|
| 6 |
+
The filename should be <dataset_mode>_dataset.py
|
| 7 |
+
The class name should be <Dataset_mode>Dataset.py
|
| 8 |
+
You need to implement the following functions:
|
| 9 |
+
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
|
| 10 |
+
-- <__init__>: Initialize this dataset class.
|
| 11 |
+
-- <__getitem__>: Return a data point and its metadata information.
|
| 12 |
+
-- <__len__>: Return the number of images.
|
| 13 |
+
"""
|
| 14 |
+
from data.base_dataset import BaseDataset, get_transform
|
| 15 |
+
# from data.image_folder import make_dataset
|
| 16 |
+
# from PIL import Image
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
|
data/unaligned_dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from data.base_dataset import BaseDataset, get_transform
|
| 3 |
+
from data.image_folder import make_dataset
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import random
|
| 6 |
+
import util.util as util
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class UnalignedDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.

    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        # Fall back to valA/valB at test time when testA does not exist but valA does.
        if opt.phase == "test" and not os.path.exists(self.dir_A) \
           and os.path.exists(os.path.join(opt.dataroot, "valA")):
            self.dir_A = os.path.join(opt.dataroot, "valA")
            self.dir_B = os.path.join(opt.dataroot, "valB")

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))   # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
        if self.opt.serial_batches:   # make sure index is within the range
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        # Apply image transformation
        # For CUT/FastCUT mode, if in finetuning phase (learning rate is decaying),
        # do not perform resize-crop data augmentation of CycleGAN.
        # current_epoch is pushed in by CustomDatasetDataLoader.set_epoch().
        is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
        modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
        transform = get_transform(modified_opt)
        A = transform(A_img)
        B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different number of images,
        we take a maximum of the two sizes.
        """
        return max(self.A_size, self.B_size)
|