File size: 2,242 Bytes
998f96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from .image_dataset import FiveKDataset
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from .features_extractor import ResnetEncoder


class PhotoEnhancement:
    """Hold the options needed to build a ``FiveKDataset`` and build it on demand.

    Construction only records the given arguments as attributes; nothing is
    loaded or encoded here. ``generate_dataset`` forwards every stored option,
    by keyword, to ``FiveKDataset``.

    NOTE(review): a previous docstring described the output as torch tensors of
    encoded and raw source/target images shaped (3, H, W). That presumably
    describes the items yielded by ``FiveKDataset`` (defined elsewhere) —
    confirm against that class.
    """

    def __init__(self, image_size,
                 mode: str = 'train',
                 resize: bool = True,
                 augment_data: bool = False,
                 use_txt_features: bool = False,
                 pre_load_images: bool = True,
                 device: str = 'cuda:0') -> None:
        """Store the dataset options.

        Args:
            image_size: Forwarded unchanged to ``FiveKDataset``.
            mode: Split name forwarded to ``FiveKDataset``
                (``create_dataloaders`` passes 'train' or 'test').
            resize: Forwarded unchanged to ``FiveKDataset``.
            augment_data: Forwarded unchanged to ``FiveKDataset``.
            use_txt_features: Forwarded unchanged to ``FiveKDataset``.
            pre_load_images: Forwarded unchanged to ``FiveKDataset``.
            device: Device string forwarded to ``FiveKDataset``.
        """
        # Which data and at what size.
        self.mode = mode
        self.image_size = image_size
        # Preprocessing switches.
        self.resize = resize
        self.augment_data = augment_data
        self.use_txt_features = use_txt_features
        # Loading / placement options.
        self.pre_load_images = pre_load_images
        self.device = device

    def generate_dataset(self):
        """Create and return a new ``FiveKDataset`` from the stored options."""
        options = dict(
            image_size=self.image_size,
            mode=self.mode,
            resize=self.resize,
            augment_data=self.augment_data,
            use_txt_features=self.use_txt_features,
            device=self.device,
            pre_load_images=self.pre_load_images,
        )
        return FiveKDataset(**options)

def create_dataloaders(batch_size, image_size, use_txt_features=False,
                       train=True, augment_data=False, shuffle=True, resize=True,
                       pre_encoding_device='cuda', pre_load_images=True,
                       **dataloader_kwargs):
    """Build a ``DataLoader`` over the FiveK 'train' or 'test' split.

    Args:
        batch_size: Batch size passed to ``DataLoader``.
        image_size: Forwarded to ``PhotoEnhancement`` / ``FiveKDataset``.
        use_txt_features: Forwarded to the dataset.
        train: Selects dataset mode 'train' when True, 'test' otherwise.
        augment_data: Forwarded to the dataset for either split.
            NOTE(review): this also applies to the test split when True —
            confirm that is intended.
        shuffle: ``DataLoader`` shuffle flag, applied to either split.
            NOTE(review): defaults to True even for the test split; evaluation
            callers may want to pass ``shuffle=False``.
        resize: Forwarded to the dataset.
        pre_encoding_device: Forwarded to the dataset as ``device``.
        pre_load_images: Forwarded to the dataset.
        **dataloader_kwargs: Extra keyword arguments forwarded to
            ``DataLoader`` (e.g. ``num_workers``, ``drop_last``,
            ``pin_memory``). Must not repeat ``batch_size`` or ``shuffle``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the generated dataset.
    """
    # The two splits were previously built by duplicated branches that
    # differed only in the mode string; select the mode once instead.
    mode = 'train' if train else 'test'
    dataset = PhotoEnhancement(image_size, mode=mode, resize=resize,
                               augment_data=augment_data,
                               use_txt_features=use_txt_features,
                               device=pre_encoding_device,
                               pre_load_images=pre_load_images).generate_dataset()
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      **dataloader_kwargs)