# AMontiB
# Your original commit message (now includes LFS pointer)
# 9c4b1c4
import torch
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler
from .datasets import dataset_folder
from torchvision.datasets import DatasetFolder
import json
import bisect
from PIL import Image
import torchvision.transforms.v2 as Tv2
'''
def get_dataset(opt):
dset_lst = []
for cls in opt.classes:
root = opt.dataroot + '/' + cls
dset = dataset_folder(opt, root)
dset_lst.append(dset)
return torch.utils.data.ConcatDataset(dset_lst)
'''
import os
# def get_dataset(opt):
# classes = os.listdir(opt.dataroot) if len(opt.classes) == 0 else opt.classes
# if '0_real' not in classes or '1_fake' not in classes:
# dset_lst = []
# for cls in classes:
# root = opt.dataroot + '/' + cls
# dset = dataset_folder(opt, root)
# dset_lst.append(dset)
# return torch.utils.data.ConcatDataset(dset_lst)
# return dataset_folder(opt, opt.dataroot)
# def get_bal_sampler(dataset):
# targets = []
# for d in dataset.datasets:
# targets.extend(d.targets)
# ratio = np.bincount(targets)
# w = 1. / torch.tensor(ratio, dtype=torch.float)
# sample_weights = w[targets]
# sampler = WeightedRandomSampler(weights=sample_weights,
# num_samples=len(sample_weights))
# return sampler
# def create_dataloader(opt):
# shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
# dataset = get_dataset(opt)
# sampler = get_bal_sampler(dataset) if opt.class_bal else None
# data_loader = torch.utils.data.DataLoader(dataset,
# batch_size=opt.batch_size,
# shuffle=shuffle,
# sampler=sampler,
# num_workers=int(opt.num_threads))
# return data_loader
def parse_dataset(settings):
    """Expand ``settings.data_keys`` into a list of dataset descriptors.

    ``data_keys`` is an ``&``-separated list of ``<generator>:<modifier>``
    pairs (e.g. ``"gan1:pre&sd2:fb"``). Short keys are mapped to the
    directory names used on disk; the aggregate keys ``all``/``real`` (for
    generators) and ``all``/``shr`` (for modifiers) expand to several names.

    Args:
        settings: object with ``task`` (str) and ``data_keys`` (str) attributes.

    Returns:
        list[dict]: one ``{'gen': [names...], 'mod': [names...]}`` per pair.

    Raises:
        ValueError: if ``task == 'train'`` and no real-image source is
            requested (training needs real data; this is not handled
            automatically).
    """
    gen_keys = {
        'gan1': ['StyleGAN'],
        'gan2': ['StyleGAN2'],
        'gan3': ['StyleGAN3'],
        'sd15': ['StableDiffusion1.5'],
        'sd2': ['StableDiffusion2'],
        'sd3': ['StableDiffusion3'],
        'sdXL': ['StableDiffusionXL'],
        'flux': ['FLUX.1'],
        'realFFHQ': ['FFHQ'],
        'realFORLAB': ['FORLAB'],
    }
    # Build 'all' before adding aggregate keys so it contains only the
    # concrete generator names above.
    gen_keys['all'] = [names[0] for names in gen_keys.values()]
    # gen_keys['gan'] = [gen_keys[key][0] for key in gen_keys.keys() if 'gan' in key]
    # gen_keys['sd'] = [gen_keys[key][0] for key in gen_keys.keys() if 'sd' in key]
    gen_keys['real'] = [gen_keys[key][0] for key in gen_keys if 'real' in key]

    mod_keys = {
        'pre': ['PreSocial'],
        'fb': ['Facebook'],
        'tl': ['Telegram'],
        'tw': ['X'],
    }
    mod_keys['all'] = [names[0] for names in mod_keys.values()]
    mod_keys['shr'] = [mod_keys[key][0] for key in ('fb', 'tl', 'tw')]

    pairs = [entry.split(':') for entry in settings.data_keys.split('&')]

    # Training without any real-image source cannot be fixed up automatically,
    # so fail loudly. (Was an `assert`, which is stripped under `python -O`.)
    if settings.task == 'train' and not any('real' in gen for gen, _ in pairs):
        raise ValueError(
            'Train task without real data, this will not get handled '
            'automatically, terminating')

    return [{'gen': gen_keys[gen], 'mod': mod_keys[mod]} for gen, mod in pairs]
class TrueFake_dataset(DatasetFolder):
    """Image dataset built by walking a TrueFake-style directory tree.

    The tree is laid out as ``<data_root>/<modifier>/<label>/<generator>/<sub>/``
    with image files in the leaf directories. Only files whose
    ``<generator>/<sub>/<stem>`` id appears in the requested split are kept.

    Each item is ``(tensor, target, path)`` where ``target`` is 1.0 for the
    'Fake' label and 0.0 otherwise.
    """

    def __init__(self, settings):
        self.data_root = settings.data_root
        self.split = settings.split
        # Split file maps split name -> list of "gen/sub/stem" ids; sorted so
        # membership can be tested with bisect in _in_list.
        with open(settings.split_file, "r") as f:
            split_list = sorted(json.load(f)[self.split])

        dataset_list = parse_dataset(settings)

        self.samples = []  # absolute file paths
        self.info = []     # parallel (modifier, label, generator, sub) tuples
        for spec in dataset_list:  # renamed from `dict`, which shadowed the builtin
            generators = spec['gen']
            modifiers = spec['mod']
            for mod in modifiers:
                for dataset_root, dataset_dirs, dataset_files in os.walk(os.path.join(self.data_root, mod), topdown=True, followlinks=True):
                    # Only leaf directories (no subdirs) contain images.
                    if len(dataset_dirs):
                        continue
                    # Strip "<data_root>/<mod>/" and split the remainder into
                    # label / generator / sub-directory components.
                    (label, gen, sub) = f'{dataset_root}/'.replace(os.path.join(self.data_root, mod) + os.sep, '').split(os.sep)[:3]
                    if gen in generators:
                        for filename in sorted(dataset_files):
                            if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']:
                                # Membership id is extension-less: "gen/sub/stem".
                                if self._in_list(split_list, os.path.join(gen, sub, os.path.splitext(filename)[0])):
                                    self.samples.append(os.path.join(dataset_root, filename))
                                    self.info.append((mod, label, gen, sub))

        # Crop: random while training, none if no_crop, otherwise center.
        if settings.isTrain:
            crop_func = Tv2.RandomCrop(settings.cropSize)
        elif settings.no_crop:
            crop_func = Tv2.Identity()
        else:
            crop_func = Tv2.CenterCrop(settings.cropSize)

        # Horizontal flip is a train-time augmentation only.
        if settings.isTrain and not settings.no_flip:
            flip_func = Tv2.RandomHorizontalFlip()
        else:
            flip_func = Tv2.Identity()

        # Resize can be skipped at evaluation time only.
        if not settings.isTrain and settings.no_resize:
            rz_func = Tv2.Identity()
        else:
            rz_func = Tv2.Resize((settings.loadSize, settings.loadSize))

        self.transform = Tv2.Compose([
            rz_func,
            crop_func,
            flip_func,
            Tv2.ToTensor(),
            # ImageNet normalization statistics.
            Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _in_list(self, split, elem):
        """Return True if `elem` is in the sorted list `split` (binary search)."""
        i = bisect.bisect_left(split, elem)
        return i != len(split) and split[i] == elem

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Load, transform and label the image at `index`.

        Returns:
            tuple: (transformed tensor, float target, source path).
        """
        path = self.samples[index]
        mod, label, gen, sub = self.info[index]
        image = Image.open(path).convert('RGB')
        sample = self.transform(image)
        target = 1.0 if label == 'Fake' else 0.0
        return sample, target, path
def create_dataloader(settings, split=None):
    """Build a DataLoader over TrueFake_dataset for one split.

    Args:
        settings: option object; ``settings.split`` is set here (it is read
            by ``TrueFake_dataset.__init__``). Also reads ``batch_size`` and
            ``num_threads``.
        split: one of ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        torch.utils.data.DataLoader: shuffled only for the train split;
            val/test use a quarter of the training batch size.

    Raises:
        ValueError: for any other ``split`` value (including None).

    Bug fix: the original divided ``settings.batch_size`` in place for
    val/test, so every additional call shrank the batch size by another
    factor of 4; the division is now local to this call.
    """
    if split not in ('train', 'val', 'test'):
        raise ValueError(f"Unknown split {split}")
    settings.split = split  # consumed by TrueFake_dataset
    is_train = split == 'train'
    batch_size = settings.batch_size if is_train else settings.batch_size // 4
    dataset = TrueFake_dataset(settings)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=int(settings.num_threads),
        shuffle=is_train,
        collate_fn=None,
    )