Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| from torchvision import transforms | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import io | |
| import torch | |
| from .dct import DCT_base_Rec_Module | |
| import random | |
| try: | |
| from torchvision.transforms import InterpolationMode | |
| BICUBIC = InterpolationMode.BICUBIC | |
| except ImportError: | |
| BICUBIC = Image.BICUBIC | |
| from PIL import ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| import kornia.augmentation as K | |
# Random training-time degradations: Gaussian blur (p=0.1) and JPEG
# re-compression with quality drawn from 30-100 (p=0.1).
Perturbations = K.container.ImageSequential(
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3.0), p=0.1),
    K.RandomJPEG(jpeg_quality=(30, 100), p=0.1),
)


def _perturb(tensor_image):
    """Run ``Perturbations`` on one image tensor and return element ``[0]``.

    NOTE(review): the ``[0]`` index assumes kornia returns a batched
    (1, C, H, W) result for a single (C, H, W) input; confirm for the
    installed kornia version.
    """
    perturbed_batch = Perturbations(tensor_image)
    return perturbed_batch[0]


# Start of the training pipeline: PIL image -> tensor -> random perturbation.
transform_before = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_perturb)])

# Start of the evaluation pipeline: PIL image -> tensor, no augmentation.
transform_before_test = transforms.Compose([transforms.ToTensor()])

# ImageNet channel statistics shared by both normalizing pipelines below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to 256x256, then normalize with the ImageNet statistics.
transform_train = transforms.Compose(
    [
        transforms.Resize([256, 256]),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Normalization only, without resizing.
transform_test_normalize = transforms.Compose(
    [transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)]
)
class TrainDataset(Dataset):
    """Dataset of real (label 0) and fake (label 1) images with random perturbation.

    Image files are collected from ``0_real`` / ``1_fake`` directories (see
    ``_collect``). Each item is ``(stack, label)``:

    - ``stack`` holds, in order, the four tensors returned by ``self.dct``
      followed by the perturbed image itself, each passed through
      ``transform_train``.
    - ``label`` is a scalar tensor.

    Unreadable images, or images the DCT module rejects, are replaced by a
    randomly drawn sample.

    Args:
        is_train: read from ``args.data_path`` when True, otherwise from
            ``args.eval_data_path``.
        args: object exposing ``data_path`` and ``eval_data_path``.
    """

    # Maximum number of random replacement draws in __getitem__. Bounded so a
    # dataset full of broken files fails loudly instead of retrying forever.
    _MAX_RETRIES = 100

    def __init__(self, is_train, args):
        root = args.data_path if is_train else args.eval_data_path
        self.data_list = []
        # normpath drops a trailing separator, so ".../train/" is still seen
        # as "train" (a plain split('/')[-1] would yield '' in that case).
        last_component = os.path.basename(os.path.normpath(root))
        if 'GenImage' in root and last_component != 'train':
            # Scan root itself.
            self._collect(root)
        else:
            # One extra level of nesting: scan every entry of root.
            for filename in os.listdir(root):
                self._collect(os.path.join(root, filename))
        self.dct = DCT_base_Rec_Module()

    def _collect(self, file_path):
        """Append the samples found under ``file_path`` to ``self.data_list``.

        ``file_path`` must either contain ``0_real`` (and ``1_fake``) directly,
        or contain only sub-folders that each hold exactly ``0_real`` and
        ``1_fake``.

        Raises:
            ValueError: if such a sub-folder holds anything else.
        """
        if '0_real' in os.listdir(file_path):
            self._collect_class_dirs(file_path)
            return
        for folder_name in os.listdir(file_path):
            folder = os.path.join(file_path, folder_name)
            entries = sorted(os.listdir(folder))
            if entries != ['0_real', '1_fake']:
                raise ValueError(
                    f"expected exactly '0_real' and '1_fake' in {folder}, found {entries}"
                )
            self._collect_class_dirs(folder)

    def _collect_class_dirs(self, folder):
        """Append every file of ``folder/0_real`` (label 0), then ``folder/1_fake`` (label 1)."""
        for label, class_dir in enumerate(('0_real', '1_fake')):
            class_path = os.path.join(folder, class_dir)
            for image_name in os.listdir(class_path):
                self.data_list.append(
                    {"image_path": os.path.join(class_path, image_name), "label": label}
                )

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``(stack of 5 views, label tensor)`` for ``index``.

        A sample that fails to load or to pass through ``self.dct`` is replaced
        by a random one. After ``_MAX_RETRIES`` consecutive replacements a
        ``RuntimeError`` is raised.
        """
        for _ in range(self._MAX_RETRIES + 1):
            sample = self.data_list[index]
            image_path, targets = sample['image_path'], sample['label']
            try:
                # Context manager closes the file once the RGB copy is decoded.
                with Image.open(image_path) as img:
                    image = img.convert('RGB')
            except Exception:
                print(f'image error: {image_path}')
                index = random.randint(0, len(self.data_list) - 1)
                continue
            image = transform_before(image)
            try:
                x_minmin, x_maxmax, x_minmin1, x_maxmax1 = self.dct(image)
            except Exception:
                print(f'image error: {image_path}, c, h, w: {image.shape}')
                index = random.randint(0, len(self.data_list) - 1)
                continue
            views = [
                transform_train(view)
                for view in (x_minmin, x_maxmax, x_minmin1, x_maxmax1, image)
            ]
            return torch.stack(views, dim=0), torch.tensor(int(targets))
        raise RuntimeError(
            f'could not load a usable sample after {self._MAX_RETRIES} random retries'
        )
class TestDataset(Dataset):
    """Evaluation dataset of real (label 0) and fake (label 1) images.

    ``root`` must either contain ``0_real`` (and ``1_fake``) directly, or
    contain only sub-folders that each hold exactly ``0_real`` and ``1_fake``.
    Images go through ``transform_before_test`` (tensor conversion only), so no
    random perturbation is applied. Load or DCT errors propagate to the caller.

    Each item is ``(stack, label)``:

    - ``stack`` holds, in order, the four tensors returned by ``self.dct``
      followed by the image itself, each passed through ``transform_train``.
    - ``label`` is a scalar tensor.

    Args:
        is_train: read from ``args.data_path`` when True, otherwise from
            ``args.eval_data_path``.
        args: object exposing ``data_path`` and ``eval_data_path``.

    Raises:
        ValueError: if a sub-folder of ``root`` does not hold exactly
            ``0_real`` and ``1_fake``.
    """

    def __init__(self, is_train, args):
        root = args.data_path if is_train else args.eval_data_path
        self.data_list = []
        if '0_real' in os.listdir(root):
            self._collect_class_dirs(root)
        else:
            for folder_name in os.listdir(root):
                folder = os.path.join(root, folder_name)
                entries = sorted(os.listdir(folder))
                if entries != ['0_real', '1_fake']:
                    raise ValueError(
                        f"expected exactly '0_real' and '1_fake' in {folder}, found {entries}"
                    )
                self._collect_class_dirs(folder)
        self.dct = DCT_base_Rec_Module()

    def _collect_class_dirs(self, folder):
        """Append every file of ``folder/0_real`` (label 0), then ``folder/1_fake`` (label 1)."""
        for label, class_dir in enumerate(('0_real', '1_fake')):
            class_path = os.path.join(folder, class_dir)
            for image_name in os.listdir(class_path):
                self.data_list.append(
                    {"image_path": os.path.join(class_path, image_name), "label": label}
                )

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``(stack of 5 views, label tensor)`` for ``index``."""
        sample = self.data_list[index]
        image_path, targets = sample['image_path'], sample['label']
        # Context manager closes the file once the RGB copy is decoded.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = transform_before_test(image)
        x_minmin, x_maxmax, x_minmin1, x_maxmax1 = self.dct(image)
        # transform_train resizes (upsamples) every view to 256x256, then normalizes.
        views = [
            transform_train(view)
            for view in (x_minmin, x_maxmax, x_minmin1, x_maxmax1, image)
        ]
        return torch.stack(views, dim=0), torch.tensor(int(targets))