Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from torch.utils.data.sampler import WeightedRandomSampler | |
| from .datasets import dataset_folder | |
| from torchvision.datasets import DatasetFolder | |
| import json | |
| import bisect | |
| from PIL import Image | |
| import torchvision.transforms.v2 as Tv2 | |
| ''' | |
| def get_dataset(opt): | |
| dset_lst = [] | |
| for cls in opt.classes: | |
| root = opt.dataroot + '/' + cls | |
| dset = dataset_folder(opt, root) | |
| dset_lst.append(dset) | |
| return torch.utils.data.ConcatDataset(dset_lst) | |
| ''' | |
| import os | |
| # def get_dataset(opt): | |
| # classes = os.listdir(opt.dataroot) if len(opt.classes) == 0 else opt.classes | |
| # if '0_real' not in classes or '1_fake' not in classes: | |
| # dset_lst = [] | |
| # for cls in classes: | |
| # root = opt.dataroot + '/' + cls | |
| # dset = dataset_folder(opt, root) | |
| # dset_lst.append(dset) | |
| # return torch.utils.data.ConcatDataset(dset_lst) | |
| # return dataset_folder(opt, opt.dataroot) | |
| # def get_bal_sampler(dataset): | |
| # targets = [] | |
| # for d in dataset.datasets: | |
| # targets.extend(d.targets) | |
| # ratio = np.bincount(targets) | |
| # w = 1. / torch.tensor(ratio, dtype=torch.float) | |
| # sample_weights = w[targets] | |
| # sampler = WeightedRandomSampler(weights=sample_weights, | |
| # num_samples=len(sample_weights)) | |
| # return sampler | |
| # def create_dataloader(opt): | |
| # shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False | |
| # dataset = get_dataset(opt) | |
| # sampler = get_bal_sampler(dataset) if opt.class_bal else None | |
| # data_loader = torch.utils.data.DataLoader(dataset, | |
| # batch_size=opt.batch_size, | |
| # shuffle=shuffle, | |
| # sampler=sampler, | |
| # num_workers=int(opt.num_threads)) | |
| # return data_loader | |
| def parse_dataset(settings): | |
| gen_keys = { | |
| 'gan1':['StyleGAN'], | |
| 'gan2':['StyleGAN2'], | |
| 'gan3':['StyleGAN3'], | |
| 'sd15':['StableDiffusion1.5'], | |
| 'sd2':['StableDiffusion2'], | |
| 'sd3':['StableDiffusion3'], | |
| 'sdXL':['StableDiffusionXL'], | |
| 'flux':['FLUX.1'], | |
| 'realFFHQ':['FFHQ'], | |
| 'realFORLAB':['FORLAB'] | |
| } | |
| gen_keys['all'] = [gen_keys[key][0] for key in gen_keys.keys()] | |
| # gen_keys['gan'] = [gen_keys[key][0] for key in gen_keys.keys() if 'gan' in key] | |
| # gen_keys['sd'] = [gen_keys[key][0] for key in gen_keys.keys() if 'sd' in key] | |
| gen_keys['real'] = [gen_keys[key][0] for key in gen_keys.keys() if 'real' in key] | |
| mod_keys = { | |
| 'pre': ['PreSocial'], | |
| 'fb': ['Facebook'], | |
| 'tl': ['Telegram'], | |
| 'tw': ['X'], | |
| } | |
| mod_keys['all'] = [mod_keys[key][0] for key in mod_keys.keys()] | |
| mod_keys['shr'] = [mod_keys[key][0] for key in mod_keys.keys() if key in ['fb', 'tl', 'tw']] | |
| need_real = (settings.task == 'train' and not len([data.split(':')[0] for data in settings.data_keys.split('&') if 'real' in data.split(':')[0]])) | |
| assert not need_real, 'Train task without real data, this will not get handeled automatically, terminating' | |
| dataset_list = [] | |
| for data in settings.data_keys.split('&'): | |
| gen, mod = data.split(':') | |
| dataset_list.append({'gen':gen_keys[gen], 'mod':mod_keys[mod]}) | |
| return dataset_list | |
| class TrueFake_dataset(DatasetFolder): | |
| def __init__(self, settings): | |
| self.data_root = settings.data_root | |
| self.split = settings.split | |
| with open(settings.split_file, "r") as f: | |
| split_list = sorted(json.load(f)[self.split]) | |
| dataset_list = parse_dataset(settings) | |
| self.samples = [] | |
| self.info = [] | |
| for dict in dataset_list: | |
| generators = dict['gen'] | |
| modifiers = dict['mod'] | |
| for mod in modifiers: | |
| for dataset_root, dataset_dirs, dataset_files in os.walk(os.path.join(self.data_root, mod), topdown=True, followlinks=True): | |
| if len(dataset_dirs): | |
| continue | |
| (label, gen, sub) = f'{dataset_root}/'.replace(os.path.join(self.data_root, mod) + os.sep, '').split(os.sep)[:3] | |
| if gen in generators: | |
| for filename in sorted(dataset_files): | |
| if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']: | |
| if self._in_list(split_list, os.path.join(gen, sub, os.path.splitext(filename)[0])): | |
| self.samples.append(os.path.join(dataset_root, filename)) | |
| self.info.append((mod, label, gen, sub)) | |
| if settings.isTrain: | |
| crop_func = Tv2.RandomCrop(settings.cropSize) | |
| elif settings.no_crop: | |
| crop_func = Tv2.Identity() | |
| else: | |
| crop_func = Tv2.CenterCrop(settings.cropSize) | |
| if settings.isTrain and not settings.no_flip: | |
| flip_func = Tv2.RandomHorizontalFlip() | |
| else: | |
| flip_func = Tv2.Identity() | |
| if not settings.isTrain and settings.no_resize: | |
| rz_func = Tv2.Identity() | |
| else: | |
| rz_func = Tv2.Resize((settings.loadSize, settings.loadSize)) | |
| self.transform = Tv2.Compose([ | |
| rz_func, | |
| crop_func, | |
| flip_func, | |
| Tv2.ToTensor(), | |
| Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
| ]) | |
| def _in_list(self, split, elem): | |
| i = bisect.bisect_left(split, elem) | |
| return i != len(split) and split[i] == elem | |
| def __len__(self): | |
| return len(self.samples) | |
| def __getitem__(self, index): | |
| path = self.samples[index] | |
| mod, label, gen, sub = self.info[index] | |
| image = Image.open(path).convert('RGB') | |
| sample = self.transform(image) | |
| target = 1.0 if label == 'Fake' else 0.0 | |
| return sample, target, path | |
| def create_dataloader(settings, split=None): | |
| if split == "train": | |
| settings.split = 'train' | |
| is_train=True | |
| elif split == "val": | |
| settings.split = 'val' | |
| settings.batch_size = settings.batch_size//4 | |
| is_train=False | |
| elif split == "test": | |
| settings.split = 'test' | |
| settings.batch_size = settings.batch_size//4 | |
| is_train=False | |
| else: | |
| raise ValueError(f"Unknown split {split}") | |
| dataset = TrueFake_dataset(settings) | |
| data_loader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=settings.batch_size, | |
| num_workers=int(settings.num_threads), | |
| shuffle = is_train, | |
| collate_fn=None, | |
| ) | |
| return data_loader | |