|
|
import torch.utils.data as data |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import numpy as np |
|
|
import random |
|
|
|
|
|
|
|
|
class BaseDataset(data.Dataset):
    """Common base class for datasets in this module.

    It provides no data access of its own: it optionally stores the options
    object and exposes two hooks (``modify_commandline_options`` and
    ``initialize``) whose default implementations do nothing.
    """

    def __init__(self, opt=None):
        """Create the dataset, remembering ``opt`` when one is supplied.

        Args:
            opt: Options object. When it is ``None`` the ``opt`` attribute is
                left unset on the instance.
        """
        super().__init__()
        if opt is None:
            return
        self.opt = opt

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding dataset-specific options to ``parser``.

        The default implementation ignores ``is_train`` and hands the parser
        back unchanged.

        Returns:
            The same ``parser`` object that was passed in.
        """
        return parser

    def initialize(self, opt):
        """Hook for option-dependent setup; the default does nothing."""
|
|
|
|
|
|
|
|
def get_params(opt, size):
    """Draw the random crop offset and flip flag for one sample.

    The size the image will have after the scaling step selected by
    ``opt.preprocess_mode`` is computed first, so that the crop offset is
    drawn from the range that is valid for the scaled image.

    Args:
        opt: Options object; ``preprocess_mode``, ``load_size`` and
            ``crop_size`` are read.
        size: ``(width, height)`` of the image before preprocessing.

    Returns:
        A dict with ``'crop_pos'`` (an ``(x, y)`` tuple for the top-left
        corner of the crop) and ``'flip'`` (a bool that is True about half of
        the time).
    """
    w, h = size
    new_w, new_h = w, h

    if opt.preprocess_mode == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess_mode == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    elif opt.preprocess_mode == 'scale_shortside_and_crop':
        short_side, long_side = min(w, h), max(w, h)
        width_is_shorter = w == short_side
        scaled_long = int(opt.load_size * long_side / short_side)
        # After scaling, the short side equals load_size and the long side
        # grows by the same factor.  Using the *original* short side here
        # would describe an image whose aspect ratio has been distorted.
        if width_is_shorter:
            new_w, new_h = opt.load_size, scaled_long
        else:
            new_w, new_h = scaled_long, opt.load_size

    # Clamp at 0 so an image smaller than the crop window yields offset 0.
    x = random.randint(0, max(0, new_w - opt.crop_size))
    y = random.randint(0, max(0, new_h - opt.crop_size))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}
|
|
|
|
|
|
|
|
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    """Assemble the preprocessing pipeline selected by ``opt``.

    Stages are appended in this order, each only when its condition holds:
    scaling (``resize`` / ``scale_width`` / ``scale_shortside``), cropping,
    rounding to a multiple of 32 (mode ``'none'``), fixed-size resize (mode
    ``'fixed'``), horizontal flip, tensor conversion, normalization.

    Args:
        opt: Options object; ``preprocess_mode``, ``load_size``,
            ``crop_size``, ``aspect_ratio``, ``isTrain`` and ``no_flip`` are
            read (some only for particular modes).
        params: Dict providing ``'crop_pos'`` and ``'flip'``; the values are
            looked up when the transform runs.
        method: Resampling filter passed to every resize step.
        normalize: When true, append a 3-channel ``Normalize`` with mean and
            std of 0.5.
        toTensor: When true, append ``ToTensor``.

    Returns:
        A ``transforms.Compose`` wrapping the selected stages.
    """
    mode = opt.preprocess_mode
    steps = []

    # Scaling: the substring checks are mutually exclusive by priority.
    if 'resize' in mode:
        square = [opt.load_size, opt.load_size]
        steps.append(transforms.Resize(square, interpolation=method))
    elif 'scale_width' in mode:
        steps.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.load_size, method)))
    elif 'scale_shortside' in mode:
        steps.append(transforms.Lambda(
            lambda img: __scale_shortside(img, opt.load_size, method)))

    # Cropping uses the offset stored in params.
    if 'crop' in mode:
        steps.append(transforms.Lambda(
            lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    # With no other preprocessing, snap both sides to a multiple of 32.
    if mode == 'none':
        base = 32
        steps.append(transforms.Lambda(
            lambda img: __make_power_2(img, base, method)))

    # Fixed output size: width is crop_size, height follows aspect_ratio.
    if mode == 'fixed':
        fixed_w = opt.crop_size
        fixed_h = round(opt.crop_size / opt.aspect_ratio)
        steps.append(transforms.Lambda(
            lambda img: __resize(img, fixed_w, fixed_h, method)))

    # Flipping is only wired in for training runs that did not disable it.
    if opt.isTrain and not opt.no_flip:
        steps.append(transforms.Lambda(
            lambda img: __flip(img, params['flip'])))

    if toTensor:
        steps.append(transforms.ToTensor())

    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5),
                                          (0.5, 0.5, 0.5)))

    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
def normalize():
    """Return a 3-channel ``Normalize`` transform with mean 0.5 and std 0.5."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
|
|
|
|
|
|
|
|
def __resize(img, w, h, method=Image.BICUBIC):
    """Resize ``img`` to exactly ``w`` x ``h`` using resampling ``method``."""
    target_size = (w, h)
    return img.resize(target_size, method)
|
|
|
|
|
|
|
|
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are multiples of ``base``.

    Despite the name, sides are rounded to the nearest multiple of ``base``,
    not to a power of two.  Each side is kept to at least one multiple of
    ``base`` so that a very small image is never given a zero-length side.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and ``resize``.
        base: Positive number each side must be a multiple of.
        method: Resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself when both sides already satisfy the constraint,
        otherwise the resized image.
    """
    ow, oh = img.size
    # max(1, ...) guards the case where a side shorter than base/2 would
    # round down to zero multiples and produce an invalid target size.
    h = int(max(1, round(oh / base)) * base)
    w = int(max(1, round(ow / base)) * base)
    if (w, h) == (ow, oh):
        return img
    return img.resize((w, h), method)
|
|
|
|
|
|
|
|
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Scale ``img`` to ``target_width``, adjusting height proportionally.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and ``resize``.
        target_width: Desired width.
        method: Resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself when its width already equals ``target_width``,
        otherwise the resized image.
    """
    cur_w, cur_h = img.size
    if cur_w != target_width:
        # Height is truncated toward zero after proportional scaling.
        scaled_h = int(target_width * cur_h / cur_w)
        img = img.resize((target_width, scaled_h), method)
    return img
|
|
|
|
|
|
|
|
def __scale_shortside(img, target_width, method=Image.BICUBIC):
    """Scale ``img`` so its shorter side equals ``target_width``.

    The longer side is scaled by the same factor (truncated to an int), so
    the aspect ratio is preserved up to rounding.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and ``resize``.
        target_width: Desired length of the shorter side.
        method: Resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself when its shorter side already equals
        ``target_width``, otherwise the resized image.
    """
    ow, oh = img.size
    short_side, long_side = min(ow, oh), max(ow, oh)
    if short_side == target_width:
        return img

    width_is_shorter = ow == short_side
    scaled_long = int(target_width * long_side / short_side)
    # The short side must become target_width; keeping its original length
    # while scaling only the long side would stretch the image.
    if width_is_shorter:
        new_size = (target_width, scaled_long)
    else:
        new_size = (scaled_long, target_width)
    return img.resize(new_size, method)
|
|
|
|
|
|
|
|
def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` square out of ``img``.

    Args:
        img: Image exposing ``crop(box)``.
        pos: ``(x, y)`` of the top-left corner of the square.
        size: Side length of the square.

    Returns:
        Whatever ``img.crop`` returns for the box
        ``(x, y, x + size, y + size)``.
    """
    left, top = pos
    return img.crop((left, top, left + size, top + size))
|
|
|
|
|
|
|
|
def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy.

    Returns:
        The transposed image when ``flip`` is truthy, otherwise ``img``
        itself.
    """
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
|
|