File size: 3,996 Bytes
98feea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import os
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as tf
from PIL import Image, ImageFile
import random
import math
from model import *
import torch
# import cv2
# cv2.setNumThreads(0)
# Tell Pillow to load truncated / partially written image files instead of
# failing on them. This sets an attribute on the PIL.ImageFile module, so it
# affects every Pillow image load in this process, not just these datasets.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class base_dataset(Dataset):
    """Paired image dataset over ``<data_dir>/input`` and ``<data_dir>/gt``.

    Pairs are formed by file name: every name listed (sorted) in
    ``<data_dir>/input`` is paired with the same name under ``<data_dir>/gt``.

    Args:
        data_dir: Root directory holding the ``input`` and ``gt`` folders.
        img_size: ``(crop_h, crop_w)`` used for the aligned random crop.
        transforms: Callable applied to both images of a pair (with the
            same random draws), or a falsy value to disable.
        crop: If truthy, take the same random ``img_size`` crop from both
            images.

    Items are ``(input, gt, input_path)``. Any image still a PIL image at
    the end of ``__getitem__`` is converted with ``TF.to_tensor`` so the
    default DataLoader collate function can batch it.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        input_dir = os.path.join(data_dir, "input")
        gt_dir = os.path.join(data_dir, "gt")
        imgs = sorted(os.listdir(input_dir))
        self.input_imgs = [os.path.join(input_dir, name) for name in imgs]
        # NOTE(review): gt paths are derived from the input names and are not
        # checked here; a missing gt file only surfaces in __getitem__.
        self.gt_imgs = [os.path.join(gt_dir, name) for name in imgs]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __getitem__(self, index):
        """Return ``(input, gt, input_path)`` for the pair at ``index``."""
        inp_img_path = self.input_imgs[index]
        gt_img_path = self.gt_imgs[index]
        inp_img = self._load_rgb(inp_img_path)
        gt_img = self._load_rgb(gt_img_path)
        if self.transforms:
            inp_img, gt_img = self._apply_paired(self.transforms, inp_img, gt_img)
        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)
        # default_collate cannot batch PIL images; convert any that remain.
        inp_img = self._to_tensor_if_pil(inp_img)
        gt_img = self._to_tensor_if_pil(gt_img)
        return inp_img, gt_img, inp_img_path

    def __len__(self):
        """Number of (input, gt) pairs."""
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        """Crop both images at the same random ``img_size`` window.

        Returns the two crops, converted to tensors if they were PIL images
        (tensors coming from ``transforms`` are left as they are).
        """
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.crop(inp_img, i, j, h, w)
        gt_img = TF.crop(gt_img, i, j, h, w)
        return self._to_tensor_if_pil(inp_img), self._to_tensor_if_pil(gt_img)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a loaded RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_tensor_if_pil(img):
        """Convert a PIL image to a tensor; return anything else unchanged."""
        return TF.to_tensor(img) if isinstance(img, Image.Image) else img

    @staticmethod
    def _apply_paired(transform, inp_img, gt_img):
        """Apply ``transform`` to both images using identical random draws.

        The torch, Python ``random`` and NumPy global RNG states are captured
        before transforming the input and restored before transforming the
        gt, so random ops (flips, crops, ...) agree across the pair.
        NOTE(review): assumes the transform draws randomness only from those
        three global generators.
        """
        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        np_state = np.random.get_state()
        inp_img = transform(inp_img)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        np.random.set_state(np_state)
        gt_img = transform(gt_img)
        return inp_img, gt_img
class random_scale_dataset(Dataset):
    """Paired input/gt dataset that also samples a random square size per item.

    Reads ``<data_dir>/input`` and ``<data_dir>/gt`` (paired by file name)
    and, for every item, samples ``down_h == down_w`` from
    ``range(int(img_size[0] * min_scale_ratio), img_size[0], scale_step)``.

    Args:
        data_dir: Root directory holding the ``input`` and ``gt`` folders.
        img_size: ``(crop_h, crop_w)``; ``img_size[0]`` also bounds the
            sampled size.
        transforms: Callable applied to both images (with the same random
            draws), or a falsy value. When set, ``crop`` is NOT applied
            (preserves the original early-return behaviour).
        crop: If truthy and no transforms are given, take the same random
            ``img_size`` crop from both images.
        min_scale_ratio: Lower bound of the sampled size as a fraction of
            ``img_size[0]`` (default 0.25, the previous hard-coded value).
        scale_step: Step between candidate sizes (default 8).

    Items are ``(input, gt, down_h, down_w, input_path)``. Images still in
    PIL form are converted with ``TF.to_tensor`` so they can be collated.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False,
                 min_scale_ratio=0.25, scale_step=8):
        input_dir = os.path.join(data_dir, "input")
        gt_dir = os.path.join(data_dir, "gt")
        imgs = sorted(os.listdir(input_dir))
        self.input_imgs = [os.path.join(input_dir, name) for name in imgs]
        # NOTE(review): gt files are assumed to share the input file names;
        # they are not checked for existence here.
        self.gt_imgs = [os.path.join(gt_dir, name) for name in imgs]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size
        self.min_scale_ratio = min_scale_ratio
        self.scale_step = scale_step

    def __getitem__(self, index):
        """Return ``(input, gt, down_h, down_w, input_path)`` for ``index``."""
        inp_img_path = self.input_imgs[index]
        gt_img_path = self.gt_imgs[index]
        inp_img = self._load_rgb(inp_img_path)
        gt_img = self._load_rgb(gt_img_path)
        down_h = down_w = self._sample_down_size()
        if self.transforms:
            inp_img, gt_img = self._apply_paired(self.transforms, inp_img, gt_img)
        elif self.crop:
            # Cropping only happens when no transforms were given.
            inp_img, gt_img = self.crop_image(inp_img, gt_img)
        # default_collate cannot batch PIL images; convert any that remain.
        inp_img = self._to_tensor_if_pil(inp_img)
        gt_img = self._to_tensor_if_pil(gt_img)
        return inp_img, gt_img, down_h, down_w, inp_img_path

    def __len__(self):
        """Number of (input, gt) pairs."""
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        """Crop both images at the same random ``img_size`` window.

        Returns the two crops, converted to tensors if they were PIL images.
        """
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.crop(inp_img, i, j, h, w)
        gt_img = TF.crop(gt_img, i, j, h, w)
        return self._to_tensor_if_pil(inp_img), self._to_tensor_if_pil(gt_img)

    def _sample_down_size(self):
        """Sample the square size for this item using Python's ``random``."""
        size = self.img_size[0]
        # randrange requires int bounds: a float start raises ValueError when
        # non-integral and TypeError on Python >= 3.12.
        start = int(size * self.min_scale_ratio)
        return random.randrange(start, size, self.scale_step)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a loaded RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_tensor_if_pil(img):
        """Convert a PIL image to a tensor; return anything else unchanged."""
        return TF.to_tensor(img) if isinstance(img, Image.Image) else img

    @staticmethod
    def _apply_paired(transform, inp_img, gt_img):
        """Apply ``transform`` to both images using identical random draws.

        Restores the torch, Python ``random`` and NumPy global RNG states
        between the two calls so random ops agree across the pair.
        NOTE(review): assumes the transform draws randomness only from those
        three global generators.
        """
        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        np_state = np.random.get_state()
        inp_img = transform(inp_img)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        np.random.set_state(np_state)
        gt_img = transform(gt_img)
        return inp_img, gt_img
def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False, inference_flag=False):
    """Build a DataLoader over a paired ``input``/``gt`` image directory.

    Args:
        data_dir: Root directory passed to the dataset class.
        img_size: Crop size passed to the dataset class.
        transforms: Transform callable (or falsy) passed to the dataset.
        crop_flag: Whether the dataset should random-crop each pair.
        batch_size, num_workers, shuffle: Forwarded to ``DataLoader``.
        random_flag: Use ``random_scale_dataset`` instead of ``base_dataset``.
        inference_flag: Accepted for caller compatibility; not read here.

    Returns:
        A ``DataLoader`` with ``pin_memory=True``.
    """
    dataset_cls = random_scale_dataset if random_flag else base_dataset
    dataset = dataset_cls(data_dir, img_size, transforms, crop_flag)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
|