# NOTE: removed scrape artifact ("Spaces: Sleeping" Hugging Face status banner)
# that was not part of the original source file.
import torch
import numpy as np
from os.path import join
from dataset import AbstractDataset

# Dataset splits accepted by WildDeepfake (validated in __init__).
SPLITS = ["train", "test"]
class WildDeepfake(AbstractDataset):
    """Wild Deepfake dataset.

    Proposed in "WildDeepfake: A Challenging Real-World Dataset for Deepfake Detection".
    Image lists are read from per-split pickle files ("real.pickle" / "fake.pickle")
    under ``<root>/<split>/``; targets are tensor 0 for real and tensor 1 for fake.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        """Load the image list for the configured split.

        Args:
            cfg: dict-like config with keys 'split' and 'root', plus optional
                'num_image_train' / 'num_image_test' caps on the sample count.
            seed: forwarded to AbstractDataset (presumably seeds sampling —
                TODO confirm against AbstractDataset).
            transforms, transform, target_transform: forwarded to AbstractDataset.

        Raises:
            ValueError: if cfg['split'] is not one of SPLITS.
        """
        # pre-check: fail fast on an unknown split name
        if cfg['split'] not in SPLITS:
            raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.")
        super(WildDeepfake, self).__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'WildDeepfake' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        # Optional per-split caps; None means "use everything".
        self.num_train = cfg.get('num_image_train', None)
        self.num_test = cfg.get('num_image_test', None)
        self.images, self.targets = self.__get_images()
        print("Data from 'WildDeepfake' loaded.")
        print(f"Dataset contains {len(self.images)} images.\n")

    def __get_images(self):
        """Read per-split pickle lists and build parallel (images, targets) lists.

        Returns:
            tuple: (list of image paths, list of scalar target tensors),
            real entries first (target 0), then fake entries (target 1).
        """
        if self.split == 'train':
            num = self.num_train
        elif self.split == 'test':
            num = self.num_test
        else:
            # unreachable given the __init__ pre-check, kept as a safe default
            num = None
        real_images = torch.load(join(self.root, self.split, "real.pickle"))
        if num is not None:
            # Subsample keeping a ~1:2 real/fake ratio.
            # BUGFIX: np.random.choice returns an ndarray; without list() the
            # final `real_images + fake_images` would be elementwise ndarray
            # addition, not list concatenation.
            real_images = list(np.random.choice(real_images, num // 3, replace=False))
        real_tgts = [torch.tensor(0)] * len(real_images)
        print(f"real: {len(real_tgts)}")
        fake_images = torch.load(join(self.root, self.split, "fake.pickle"))
        if num is not None:
            # BUGFIX: same ndarray -> list cast as above.
            fake_images = list(np.random.choice(fake_images, num - num // 3, replace=False))
        fake_tgts = [torch.tensor(1)] * len(fake_images)
        print(f"fake: {len(fake_tgts)}")
        return real_images + fake_images, real_tgts + fake_tgts

    def __getitem__(self, index):
        """Return (absolute image path, target tensor) for the given index."""
        path = join(self.root, self.split, self.images[index])
        tgt = self.targets[index]
        return path, tgt
if __name__ == '__main__':
    import yaml

    config_path = "../config/dataset/wilddeepfake.yml"
    with open(config_path) as config_file:
        # safe_load is sufficient for a plain config mapping and avoids
        # constructing arbitrary Python objects from the YAML file.
        config = yaml.safe_load(config_file)
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        """Smoke-test: print the first 10 (path, target) samples."""
        dataset = WildDeepfake(config)
        print(f"dataset: {len(dataset)}")
        for i, _ in enumerate(dataset):
            path, target = _
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break

    def run_dataloader(display_samples=False):
        """Smoke-test: iterate 10 batches through a DataLoader.

        Args:
            display_samples: if True, show the first image of each batch
                with matplotlib.
        """
        from torch.utils import data
        import matplotlib.pyplot as plt
        dataset = WildDeepfake(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, _ in enumerate(dataloader):
            path, targets = _
            # load_item is provided by AbstractDataset — it turns a batch of
            # paths into an image tensor (assumed NCHW; TODO confirm).
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break

    ###########################
    # run the functions below #
    ###########################
    # run_dataset()
    run_dataloader(False)