photo-enhancer / src /envs /env_dataloader.py
zakaria-narjis's picture
add src and models
998f96a
from .image_dataset import FiveKDataset
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from .features_extractor import ResnetEncoder
class PhotoEnhancement:
"""
Encode dataset images
output : torch.Tensors of encoded and raw source/target images (3,H,W)
"""
def __init__(self,image_size,
mode = 'train',
resize=True,
augment_data=False,
use_txt_features=False,
pre_load_images=True,
device='cuda:0') -> None:
self.image_size = image_size
self.mode = mode
self.resize = resize
self.augment_data = augment_data
self.use_txt_features = use_txt_features
self.pre_load_images = pre_load_images
self.device = device
def generate_dataset(self):
return FiveKDataset(image_size=self.image_size,mode=self.mode,
resize=self.resize, augment_data=self.augment_data,
use_txt_features=self.use_txt_features,device=self.device,pre_load_images=self.pre_load_images)
def create_dataloaders(batch_size,image_size,use_txt_features=False,
train=True,augment_data=False,shuffle=True,resize=True,pre_encoding_device='cuda',pre_load_images=True):
if train:
train_dataset = PhotoEnhancement(image_size, mode='train', resize=resize,
augment_data=augment_data,
use_txt_features=use_txt_features,
device=pre_encoding_device,pre_load_images=pre_load_images)
train_dataset = train_dataset.generate_dataset()
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle = shuffle)
else:
test_dataset = PhotoEnhancement(image_size, mode='test', resize=resize,
augment_data=augment_data,
use_txt_features=use_txt_features,
device=pre_encoding_device,pre_load_images=pre_load_images)
test_dataset = test_dataset.generate_dataset()
dataloader = DataLoader(test_dataset, batch_size=batch_size , shuffle = shuffle)
return dataloader