|
|
import os |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data.dataset import Dataset |
|
|
from PIL import Image |
|
|
import torchvision.transforms.functional as TF |
|
|
import torchvision.transforms as tf |
|
|
from PIL import Image, ImageFile |
|
|
import random |
|
|
import math |
|
|
from model import * |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# Set Pillow's module-level LOAD_TRUNCATED_IMAGES flag so that image files
# whose data ends early can still be loaded instead of aborting the load.
# NOTE(review): this is a global setting on PIL.ImageFile, so it applies to
# every image opened through PIL in this process, not only to the datasets below.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
class base_dataset(Dataset):
    """Paired image dataset backed by ``<data_dir>/input`` and ``<data_dir>/gt``.

    File names are read (sorted) from the ``input`` directory; the ground-truth
    image of each sample is the file with the same name under ``gt``.

    Args:
        data_dir: Directory that contains the ``input`` and ``gt`` sub-directories.
        img_size: ``(height, width)`` pair used as the random-crop size.
        transforms: Optional callable applied to the input image and to the
            ground-truth image; any falsy value disables it.
        crop: When truthy, the same random ``img_size`` window is cropped from
            both images and both are returned as tensors.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        # Paths are built exactly as "<data_dir>/input/<name>" so the path string
        # returned by __getitem__ keeps its established format.
        input_dir = data_dir + "/input"
        gt_dir = data_dir + "/gt"
        names = sorted(os.listdir(input_dir))
        self.input_imgs = [os.path.join(input_dir, name) for name in names]
        self.gt_imgs = [os.path.join(gt_dir, name) for name in names]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(input_image, gt_image, input_path)`` for sample ``index``.

        ``transforms`` (if set) runs first, then the optional random crop.
        """
        inp_img_path = self.input_imgs[index]
        inp_img = self._load_rgb(inp_img_path)
        gt_img = self._load_rgb(self.gt_imgs[index])

        if self.transforms:
            # NOTE(review): the callable is invoked once per image, so a transform
            # with internal randomness would treat input and gt differently.
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)

        return inp_img, gt_img, inp_img_path

    def __len__(self):
        """Number of (input, gt) pairs."""
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        """Crop the same random ``img_size`` window from both images.

        Returns:
            ``(inp_img, gt_img)`` as tensors.
        """
        crop_h, crop_w = self.img_size
        # Sample the window once so both images are cropped at the same location.
        top, left, height, width = tf.RandomCrop.get_params(
            inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.crop(inp_img, top, left, height, width)
        gt_img = TF.crop(gt_img, top, left, height, width)

        # TF.to_tensor rejects torch.Tensor input, which is what arrives here when
        # `transforms` already converted the images; convert only when still needed.
        if not isinstance(inp_img, torch.Tensor):
            inp_img = TF.to_tensor(inp_img)
        if not isinstance(gt_img, torch.Tensor):
            gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img
|
|
|
|
|
|
|
|
class random_scale_dataset(Dataset):
    """Paired image dataset that also yields a random down-scale size per sample.

    File names are read (sorted) from ``<data_dir>/input``; the ground-truth
    image of each sample is the file with the same name under ``<data_dir>/gt``.

    Args:
        data_dir: Directory that contains the ``input`` and ``gt`` sub-directories.
        img_size: ``(height, width)`` pair; used as the random-crop size, and its
            first element bounds the random down-scale size.
        transforms: Optional callable applied to the input image and to the
            ground-truth image; any falsy value disables it.
        crop: When truthy (and ``transforms`` is not set), the same random
            ``img_size`` window is cropped from both images.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        # Paths are built exactly as "<data_dir>/input/<name>" so the path string
        # returned by __getitem__ keeps its established format.
        input_dir = data_dir + "/input"
        gt_dir = data_dir + "/gt"
        names = sorted(os.listdir(input_dir))
        self.input_imgs = [os.path.join(input_dir, name) for name in names]
        self.gt_imgs = [os.path.join(gt_dir, name) for name in names]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _random_scale(self):
        """Draw a size from ``[img_size[0] * 0.25, img_size[0])`` in steps of 8."""
        upper = int(self.img_size[0])
        # randrange needs integer bounds: a float start raises TypeError on
        # Python 3.12+ (and ValueError earlier when the quarter size is not whole).
        lower = int(self.img_size[0] * 0.25)
        return random.randrange(lower, upper, 8)

    def __getitem__(self, index):
        """Return ``(input_image, gt_image, down_h, down_w, input_path)``.

        ``down_h`` and ``down_w`` are the same randomly drawn size. When
        ``transforms`` is set the transformed pair is returned directly and no
        crop is applied; otherwise the pair is randomly cropped if ``crop`` is set.
        """
        inp_img_path = self.input_imgs[index]
        inp_img = self._load_rgb(inp_img_path)
        gt_img = self._load_rgb(self.gt_imgs[index])

        down_h = down_w = self._random_scale()

        if self.transforms:
            # NOTE(review): the callable is invoked once per image, so a transform
            # with internal randomness would treat input and gt differently.
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)
            return inp_img, gt_img, down_h, down_w, inp_img_path

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)

        # Single exit so a sample is always returned, even with neither option set.
        return inp_img, gt_img, down_h, down_w, inp_img_path

    def __len__(self):
        """Number of (input, gt) pairs."""
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        """Crop the same random ``img_size`` window from both images.

        Returns:
            ``(inp_img, gt_img)`` as tensors.
        """
        crop_h, crop_w = self.img_size
        # Sample the window once so both images are cropped at the same location.
        top, left, height, width = tf.RandomCrop.get_params(
            inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.crop(inp_img, top, left, height, width)
        gt_img = TF.crop(gt_img, top, left, height, width)

        # TF.to_tensor rejects torch.Tensor input; convert only when still needed.
        if not isinstance(inp_img, torch.Tensor):
            inp_img = TF.to_tensor(inp_img)
        if not isinstance(gt_img, torch.Tensor):
            gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img
|
|
|
|
|
|
|
|
def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False, inference_flag=False):
    """Build a ``DataLoader`` over one of the paired-image datasets.

    Args:
        data_dir: Directory handed to the dataset constructor.
        img_size: Size pair handed to the dataset constructor.
        transforms: Transform callable (or falsy value) handed to the dataset.
        crop_flag: Crop switch handed to the dataset.
        batch_size: Batch size of the loader.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader shuffles samples.
        random_flag: Use ``random_scale_dataset`` when truthy, ``base_dataset``
            otherwise.
        inference_flag: Accepted for interface compatibility; not used here.

    Returns:
        The configured ``DataLoader`` (always created with ``pin_memory=True``).
    """
    # Both dataset classes share one constructor signature, so pick the class
    # first and build the loader in a single place.
    dataset_cls = random_scale_dataset if random_flag else base_dataset
    dataset = dataset_cls(data_dir, img_size, transforms, crop_flag)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
|
|
|