File size: 3,996 Bytes
98feea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as tf
from PIL import Image, ImageFile
import random
import math
from model import *
import torch
# import cv2
# cv2.setNumThreads(0)

ImageFile.LOAD_TRUNCATED_IMAGES = True


class base_dataset(Dataset):
    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        imgs = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", name) for name in imgs]
        self.gt_imgs = [os.path.join(data_dir + "/gt", name) for name in imgs]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __getitem__(self, index):
        inp_img_path = self.input_imgs[index]
        gt_img_path = self.gt_imgs[index]
        inp_img = Image.open(inp_img_path).convert("RGB")
        gt_img = Image.open(gt_img_path).convert("RGB")
        if self.transforms:
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)

        return inp_img, gt_img, inp_img_path

    def __len__(self):
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(
            inp_img, output_size=((crop_h, crop_w)))
        inp_img = TF.crop(inp_img, i, j, h, w)
        gt_img = TF.crop(gt_img, i, j, h, w)
        inp_img = TF.to_tensor(inp_img)
        gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img


class random_scale_dataset(Dataset):
    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        imgs = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", name) for name in imgs]
        self.gt_imgs = [os.path.join(data_dir + "/gt", name) for name in imgs]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __getitem__(self, index):
        inp_img_path = self.input_imgs[index]
        gt_img_path = self.gt_imgs[index]
        inp_img = Image.open(inp_img_path).convert("RGB")
        gt_img = Image.open(gt_img_path).convert("RGB")

        random_scale_factor = random.randrange(self.img_size[0] * 0.25, self.img_size[0], 8)
        down_h = down_w = random_scale_factor

        if self.transforms:
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)
            return inp_img, gt_img, down_h, down_w, inp_img_path

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)
            return inp_img, gt_img, down_h, down_w, inp_img_path

    def __len__(self):
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(
            inp_img, output_size=((crop_h, crop_w)))
        inp_img = TF.crop(inp_img, i, j, h, w)
        gt_img = TF.crop(gt_img, i, j, h, w)
        inp_img = TF.to_tensor(inp_img)
        gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img


def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False, inference_flag=False):
    if random_flag:
        dataset = random_scale_dataset(data_dir, img_size, transforms, crop_flag)
        dataloader = DataLoader(dataset, batch_size=batch_size,
                                shuffle=shuffle, num_workers=num_workers, pin_memory=True)
    else:
        dataset = base_dataset(data_dir, img_size, transforms, crop_flag)
        dataloader = DataLoader(dataset, batch_size=batch_size,
                                shuffle=shuffle, num_workers=num_workers, pin_memory=True)
    return dataloader