Spaces:
Running
Running
| import os | |
| import torch | |
| from torchvision import datasets | |
| import torchvision.transforms.v2 as Tv2 | |
| from PIL import Image, ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| import json | |
| import bisect | |
| def parse_dataset(settings): | |
| gen_keys = { | |
| 'gan1':['StyleGAN'], | |
| 'gan2':['StyleGAN2'], | |
| 'gan3':['StyleGAN3'], | |
| 'sd15':['StableDiffusion1.5'], | |
| 'sd2':['StableDiffusion2'], | |
| 'sd3':['StableDiffusion3'], | |
| 'sdXL':['StableDiffusionXL'], | |
| 'flux':['FLUX.1'], | |
| 'realFFHQ':['FFHQ'], | |
| 'realFORLAB':['FORLAB'] | |
| } | |
| gen_keys['all'] = [gen_keys[key][0] for key in gen_keys.keys()] | |
| gen_keys['real'] = [gen_keys[key][0] for key in gen_keys.keys() if 'real' in key] | |
| mod_keys = { | |
| 'pre': ['PreSocial'], | |
| 'fb': ['Facebook'], | |
| 'tl': ['Telegram'], | |
| 'tw': ['X'], | |
| } | |
| mod_keys['all'] = [mod_keys[key][0] for key in mod_keys.keys()] | |
| mod_keys['shr'] = [mod_keys[key][0] for key in mod_keys.keys() if key in ['fb', 'tl', 'tw']] | |
| need_real = (settings.task == 'train' and not len([data.split(':')[0] for data in settings.data_keys.split('&') if 'real' in data.split(':')[0]])) | |
| assert not need_real, 'Train task without real data, this will not get handeled automatically, terminating' | |
| dataset_list = [] | |
| for data in settings.data_keys.split('&'): | |
| gen, mod = data.split(':') | |
| dataset_list.append({'gen':gen_keys[gen], 'mod':mod_keys[mod]}) | |
| return dataset_list | |
| class TrueFake_dataset(datasets.DatasetFolder): | |
| def __init__(self, settings): | |
| self.data_root = settings.data_root | |
| self.split = settings.split | |
| with open(settings.split_file, "r") as f: | |
| split_list = sorted(json.load(f)[self.split]) | |
| dataset_list = parse_dataset(settings) | |
| self.samples = [] | |
| self.info = [] | |
| for dict in dataset_list: | |
| generators = dict['gen'] | |
| modifiers = dict['mod'] | |
| for mod in modifiers: | |
| for dataset_root, dataset_dirs, dataset_files in os.walk(os.path.join(self.data_root, mod), topdown=True, followlinks=True): | |
| if len(dataset_dirs): | |
| continue | |
| (label, gen, sub) = f'{dataset_root}/'.replace(os.path.join(self.data_root, mod) + os.sep, '').split(os.sep)[:3] | |
| if gen in generators: | |
| for filename in sorted(dataset_files): | |
| if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']: | |
| if self._in_list(split_list, os.path.join(gen, sub, os.path.splitext(filename)[0])): | |
| self.samples.append(os.path.join(dataset_root, filename)) | |
| self.info.append((mod, label, gen, sub)) | |
| self.transform_start = Tv2.Compose( | |
| [ | |
| Tv2.ToImage() | |
| ] | |
| ) | |
| self.transform_end = Tv2.Compose( | |
| [ | |
| Tv2.CenterCrop(1024) if self.split == 'test' and 'realFORLAB:pre' in settings.data_keys else Tv2.Identity(), | |
| Tv2.CenterCrop(720) if self.split == 'test' and 'realFORLAB:fb' in settings.data_keys else Tv2.Identity(), | |
| Tv2.CenterCrop(1200) if self.split == 'test' and 'realFORLAB:tw' in settings.data_keys else Tv2.Identity(), | |
| Tv2.CenterCrop(800) if self.split == 'test' and 'realFORLAB:tl' in settings.data_keys else Tv2.Identity(), | |
| Tv2.ToDtype(torch.float32, scale=True), | |
| Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
| ] | |
| ) | |
| if self.split in ['train', 'val']: | |
| self.transform_aug = { | |
| 'light': Tv2.Compose( | |
| [ | |
| Tv2.RandomChoice([Tv2.RandomResizedCrop([300], (0.5, 1.5), (0.5, 2)), Tv2.RandomCrop([300])], p=[0.3, 0.7]), | |
| Tv2.Compose([Tv2.RandomHorizontalFlip(p=0.5), Tv2.RandomVerticalFlip(p=0.5)]), | |
| Tv2.RandomCrop(96, pad_if_needed=True) if self.split == 'train' else Tv2.Identity(), | |
| ] | |
| ), | |
| 'heavy': Tv2.Compose( | |
| [ | |
| Tv2.RandomChoice([Tv2.RandomResizedCrop([300], (0.5, 1.5), (0.5, 2)), Tv2.RandomCrop([300])], p=[0.3, 0.7]), | |
| Tv2.RandomApply([Tv2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)], p=0.3), | |
| Tv2.RandomApply([Tv2.GaussianBlur(kernel_size=11, sigma=(0.1,3))], p=0.3), | |
| Tv2.RandomApply([Tv2.JPEG((65, 95))], p=0.3), | |
| Tv2.Compose([Tv2.RandomHorizontalFlip(p=0.5), Tv2.RandomVerticalFlip(p=0.5)]), | |
| Tv2.RandomCrop(96, pad_if_needed=True) if self.split == 'train' else Tv2.Identity(), | |
| ] | |
| ) | |
| } | |
| else: | |
| self.transform_aug = None | |
| print() | |
| print(f'Transforms for {self.split}:') | |
| print(self.transform_start) | |
| if self.transform_aug: | |
| print(self.transform_aug['light']) | |
| print(self.transform_aug['heavy']) | |
| print(self.transform_end) | |
| print(f'Loaded {len(self.samples)} samples for {self.split}') | |
| def _in_list(self, split, elem): | |
| i = bisect.bisect_left(split, elem) | |
| return i != len(split) and split[i] == elem | |
| def __len__(self): | |
| return len(self.samples) | |
| def __getitem__(self, index): | |
| path = self.samples[index] | |
| mod, label, gen, sub = self.info[index] | |
| image = Image.open(path).convert('RGB') | |
| sample = self.transform_start(image) | |
| if self.transform_aug: | |
| sample = self.transform_aug['heavy' if mod == 'PreSocial' else 'light'](sample) | |
| sample = self.transform_end(sample) | |
| target = 1.0 if label == 'Fake' else 0.0 | |
| return sample, target, path | |
| def create_dataloader(settings, split=None): | |
| if split == "train": | |
| settings.split = 'train' | |
| is_train=True | |
| elif split == "val": | |
| settings.split = 'val' | |
| is_train=False | |
| elif split == "test": | |
| settings.split = 'test' | |
| settings.batch_size = settings.batch_size//8 | |
| is_train=False | |
| else: | |
| raise ValueError(f"Unknown split {split}") | |
| dataset = TrueFake_dataset(settings) | |
| data_loader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=settings.batch_size, | |
| num_workers=int(settings.num_threads), | |
| shuffle = is_train, | |
| collate_fn=None, | |
| ) | |
| return data_loader | |