| | import torch.utils.data as data |
| | from PIL import Image |
| | import torchvision.transforms as transforms |
| |
|
| |
|
class BaseDataset(data.Dataset):
    """Base dataset whose setup is driven by an options object.

    The base class holds no data: ``initialize`` does nothing and the
    reported length is zero. Concrete datasets presumably override these
    (and ``__getitem__``) — confirm against subclasses.
    """

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the dataset's display name."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Configure the dataset from ``opt``; a no-op in the base class."""
        return None

    def __len__(self):
        """Return the number of samples, which is always 0 for the base class."""
        return 0
| |
|
| |
|
def get_transform(opt):
    """Build the image preprocessing pipeline selected by ``opt``.

    ``opt.resize_or_crop`` picks the geometric step:
      * ``'resize_and_crop'``: bicubic resize to ``loadSize x loadSize``,
        then random crop to ``fineSize``.
      * ``'crop'``: random crop to ``fineSize`` only.
      * ``'scale_width'``: scale to width ``fineSize`` with
        ``__scale_width``.
      * ``'scale_width_and_crop'``: scale to width ``loadSize``, then
        random crop to ``fineSize``.
      * ``'none'``: only pad sizes up to multiples of 4 with ``__adjust``.

    A random horizontal flip is added when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not. Every pipeline ends with ``ToTensor`` and a
    per-channel ``Normalize(mean=0.5, std=0.5)``. This assumes 3-channel
    input; TODO confirm for other modes.

    Raises:
        ValueError: if ``opt.resize_or_crop`` is not one of the modes above.
    """
    # Local import so this block needs no change to the file-level imports.
    import functools

    mode = opt.resize_or_crop
    transform_list = []
    # The helpers are wrapped with functools.partial or passed directly rather
    # than in lambdas. A closure local to this function cannot be pickled,
    # which breaks multi-worker DataLoaders under the "spawn" start method.
    # The option values are captured when the pipeline is built.
    if mode == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'scale_width':
        transform_list.append(transforms.Lambda(
            functools.partial(__scale_width, target_width=opt.fineSize)))
    elif mode == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            functools.partial(__scale_width, target_width=opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'none':
        transform_list.append(transforms.Lambda(__adjust))
    else:
        raise ValueError('--resize_or_crop %s is not a valid option.' % mode)

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
| |
|
| |
|
| | |
def __adjust(img):
    """Pad an image's size up to multiples of 4 by resizing it.

    If both sides of ``img.size`` are already multiples of 4, ``img`` itself
    is returned. Otherwise each side is rounded up to the next multiple of
    4, ``__print_size_warning`` is called with the old and new sizes, and a
    bicubic-resized copy is returned.
    """
    ow, oh = img.size
    mult = 4
    if not (ow % mult or oh % mult):
        return img

    # Round each dimension up to the next multiple of `mult`.
    w = ((ow - 1) // mult + 1) * mult
    h = ((oh - 1) // mult + 1) * mult

    if (ow, oh) != (w, h):
        __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)
| |
|
| |
|
def __scale_width(img, target_width):
    """Resize ``img`` to ``target_width`` pixels wide, keeping its aspect ratio.

    The proportional height is truncated to an int. It is then rounded up to
    a multiple of 4, and never made smaller than 4. ``img`` is returned
    unchanged when it is already ``target_width`` wide and its height is a
    multiple of 4. When the height has to be rounded, ``__print_size_warning``
    is called.

    Raises:
        AssertionError: if ``target_width`` is not a multiple of 4. This
            check is skipped when Python runs with ``-O``.
    """
    ow, oh = img.size
    mult = 4
    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
    if ow == target_width and oh % mult == 0:
        return img
    w = target_width
    target_height = int(target_width * oh / ow)
    # Round the height up to the next multiple of `mult`. For a very wide
    # image the proportional height truncates to 0, which would give h == 0.
    # PIL refuses to resize to a zero dimension, so keep at least `mult` rows.
    h = max(((target_height - 1) // mult + 1) * mult, mult)

    if target_height != h:
        __print_size_warning(target_width, target_height, w, h)

    return img.resize((w, h), Image.BICUBIC)
| |
|
| |
|
def __print_size_warning(ow, oh, w, h):
    """Print a resize notice, but only on the first call.

    ``(ow, oh)`` is reported as the loaded size and ``(w, h)`` as the size
    it was adjusted to. After printing, the function object gets a
    ``has_printed`` attribute. Any later call that finds that attribute
    returns without printing.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    message = ("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "
               "(%d, %d). This adjustment will be done to all images "
               "whose sizes are not multiples of 4" % (ow, oh, w, h))
    print(message)
    __print_size_warning.has_printed = True