AMontiB
Your original commit message (now includes LFS pointer)
9c4b1c4
import os
import torch
from torchvision import datasets
import torchvision.transforms.v2 as Tv2
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import json
import bisect
def parse_dataset(settings):
gen_keys = {
'gan1':['StyleGAN'],
'gan2':['StyleGAN2'],
'gan3':['StyleGAN3'],
'sd15':['StableDiffusion1.5'],
'sd2':['StableDiffusion2'],
'sd3':['StableDiffusion3'],
'sdXL':['StableDiffusionXL'],
'flux':['FLUX.1'],
'realFFHQ':['FFHQ'],
'realFORLAB':['FORLAB']
}
gen_keys['all'] = [gen_keys[key][0] for key in gen_keys.keys()]
gen_keys['real'] = [gen_keys[key][0] for key in gen_keys.keys() if 'real' in key]
mod_keys = {
'pre': ['PreSocial'],
'fb': ['Facebook'],
'tl': ['Telegram'],
'tw': ['X'],
}
mod_keys['all'] = [mod_keys[key][0] for key in mod_keys.keys()]
mod_keys['shr'] = [mod_keys[key][0] for key in mod_keys.keys() if key in ['fb', 'tl', 'tw']]
need_real = (settings.task == 'train' and not len([data.split(':')[0] for data in settings.data_keys.split('&') if 'real' in data.split(':')[0]]))
assert not need_real, 'Train task without real data, this will not get handeled automatically, terminating'
dataset_list = []
for data in settings.data_keys.split('&'):
gen, mod = data.split(':')
dataset_list.append({'gen':gen_keys[gen], 'mod':mod_keys[mod]})
return dataset_list
class TrueFake_dataset(datasets.DatasetFolder):
def __init__(self, settings):
self.data_root = settings.data_root
self.split = settings.split
with open(settings.split_file, "r") as f:
split_list = sorted(json.load(f)[self.split])
dataset_list = parse_dataset(settings)
self.samples = []
self.info = []
for dict in dataset_list:
generators = dict['gen']
modifiers = dict['mod']
for mod in modifiers:
for dataset_root, dataset_dirs, dataset_files in os.walk(os.path.join(self.data_root, mod), topdown=True, followlinks=True):
if len(dataset_dirs):
continue
(label, gen, sub) = f'{dataset_root}/'.replace(os.path.join(self.data_root, mod) + os.sep, '').split(os.sep)[:3]
if gen in generators:
for filename in sorted(dataset_files):
if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']:
if self._in_list(split_list, os.path.join(gen, sub, os.path.splitext(filename)[0])):
self.samples.append(os.path.join(dataset_root, filename))
self.info.append((mod, label, gen, sub))
self.transform_start = Tv2.Compose(
[
Tv2.ToImage()
]
)
self.transform_end = Tv2.Compose(
[
Tv2.CenterCrop(1024) if self.split == 'test' and 'realFORLAB:pre' in settings.data_keys else Tv2.Identity(),
Tv2.CenterCrop(720) if self.split == 'test' and 'realFORLAB:fb' in settings.data_keys else Tv2.Identity(),
Tv2.CenterCrop(1200) if self.split == 'test' and 'realFORLAB:tw' in settings.data_keys else Tv2.Identity(),
Tv2.CenterCrop(800) if self.split == 'test' and 'realFORLAB:tl' in settings.data_keys else Tv2.Identity(),
Tv2.ToDtype(torch.float32, scale=True),
Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
if self.split in ['train', 'val']:
self.transform_aug = {
'light': Tv2.Compose(
[
Tv2.RandomChoice([Tv2.RandomResizedCrop([300], (0.5, 1.5), (0.5, 2)), Tv2.RandomCrop([300])], p=[0.3, 0.7]),
Tv2.Compose([Tv2.RandomHorizontalFlip(p=0.5), Tv2.RandomVerticalFlip(p=0.5)]),
Tv2.RandomCrop(96, pad_if_needed=True) if self.split == 'train' else Tv2.Identity(),
]
),
'heavy': Tv2.Compose(
[
Tv2.RandomChoice([Tv2.RandomResizedCrop([300], (0.5, 1.5), (0.5, 2)), Tv2.RandomCrop([300])], p=[0.3, 0.7]),
Tv2.RandomApply([Tv2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)], p=0.3),
Tv2.RandomApply([Tv2.GaussianBlur(kernel_size=11, sigma=(0.1,3))], p=0.3),
Tv2.RandomApply([Tv2.JPEG((65, 95))], p=0.3),
Tv2.Compose([Tv2.RandomHorizontalFlip(p=0.5), Tv2.RandomVerticalFlip(p=0.5)]),
Tv2.RandomCrop(96, pad_if_needed=True) if self.split == 'train' else Tv2.Identity(),
]
)
}
else:
self.transform_aug = None
print()
print(f'Transforms for {self.split}:')
print(self.transform_start)
if self.transform_aug:
print(self.transform_aug['light'])
print(self.transform_aug['heavy'])
print(self.transform_end)
print(f'Loaded {len(self.samples)} samples for {self.split}')
def _in_list(self, split, elem):
i = bisect.bisect_left(split, elem)
return i != len(split) and split[i] == elem
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
path = self.samples[index]
mod, label, gen, sub = self.info[index]
image = Image.open(path).convert('RGB')
sample = self.transform_start(image)
if self.transform_aug:
sample = self.transform_aug['heavy' if mod == 'PreSocial' else 'light'](sample)
sample = self.transform_end(sample)
target = 1.0 if label == 'Fake' else 0.0
return sample, target, path
def create_dataloader(settings, split=None):
if split == "train":
settings.split = 'train'
is_train=True
elif split == "val":
settings.split = 'val'
is_train=False
elif split == "test":
settings.split = 'test'
settings.batch_size = settings.batch_size//8
is_train=False
else:
raise ValueError(f"Unknown split {split}")
dataset = TrueFake_dataset(settings)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=settings.batch_size,
num_workers=int(settings.num_threads),
shuffle = is_train,
collate_fn=None,
)
return data_loader