Spaces:
Sleeping
Sleeping
| from .image_dataset import FiveKDataset | |
| from torch.utils.data import Dataset | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm.notebook import tqdm | |
| from .features_extractor import ResnetEncoder | |
| class PhotoEnhancement: | |
| """ | |
| Encode dataset images | |
| output : torch.Tensors of encoded and raw source/target images (3,H,W) | |
| """ | |
| def __init__(self,image_size, | |
| mode = 'train', | |
| resize=True, | |
| augment_data=False, | |
| use_txt_features=False, | |
| pre_load_images=True, | |
| device='cuda:0') -> None: | |
| self.image_size = image_size | |
| self.mode = mode | |
| self.resize = resize | |
| self.augment_data = augment_data | |
| self.use_txt_features = use_txt_features | |
| self.pre_load_images = pre_load_images | |
| self.device = device | |
| def generate_dataset(self): | |
| return FiveKDataset(image_size=self.image_size,mode=self.mode, | |
| resize=self.resize, augment_data=self.augment_data, | |
| use_txt_features=self.use_txt_features,device=self.device,pre_load_images=self.pre_load_images) | |
| def create_dataloaders(batch_size,image_size,use_txt_features=False, | |
| train=True,augment_data=False,shuffle=True,resize=True,pre_encoding_device='cuda',pre_load_images=True): | |
| if train: | |
| train_dataset = PhotoEnhancement(image_size, mode='train', resize=resize, | |
| augment_data=augment_data, | |
| use_txt_features=use_txt_features, | |
| device=pre_encoding_device,pre_load_images=pre_load_images) | |
| train_dataset = train_dataset.generate_dataset() | |
| dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle = shuffle) | |
| else: | |
| test_dataset = PhotoEnhancement(image_size, mode='test', resize=resize, | |
| augment_data=augment_data, | |
| use_txt_features=use_txt_features, | |
| device=pre_encoding_device,pre_load_images=pre_load_images) | |
| test_dataset = test_dataset.generate_dataset() | |
| dataloader = DataLoader(test_dataset, batch_size=batch_size , shuffle = shuffle) | |
| return dataloader |