|
|
|
|
|
|
|
|
|
|
| import sys
|
|
|
| import lmdb
|
|
|
| sys.path.append('.')
|
|
|
| import os
|
| import math
|
| import yaml
|
| import glob
|
| import json
|
|
|
| import numpy as np
|
| from copy import deepcopy
|
| import cv2
|
| import random
|
| from PIL import Image
|
| from collections import defaultdict
|
|
|
| import torch
|
| from torch.autograd import Variable
|
| from torch.utils import data
|
| from torchvision import transforms as T
|
|
|
| import albumentations as A
|
|
|
| from .albu import IsotropicResize
|
|
|
# Dataset names that all live in the shared FaceForensics++ LMDB store.
FFpp_pool = [
    'FaceForensics++',
    'FaceShifter',
    'DeepFakeDetection',
    'FF-DF',
    'FF-F2F',
    'FF-FS',
    'FF-NT',
]
|
| import pdb
|
|
|
def all_in_pool(inputs, pool):
    """Return True when every element of ``inputs`` is contained in ``pool``.

    An empty ``inputs`` yields True.
    """
    return all(candidate in pool for candidate in inputs)
|
|
|
|
|
class DeepfakeAbstractBaseDataset(data.Dataset):
    """
    Abstract base class for all deepfake datasets.

    The dataset reads per-dataset JSON index files (``<dataset_json_folder>/<name>.json``),
    builds parallel lists of frame paths (or clips of frame paths in video mode) and
    integer labels, and serves samples as tensors, optionally with landmarks and masks.
    Frames are read either from the filesystem or from an LMDB store (``config['lmdb']``).
    """

    # Suffix marking a dataset alias that forces the 'c40' compression split.
    _C40_SUFFIX = '_c40'
    _C40_ALIASES = ('FaceForensics++_c40', 'FF-DF_c40', 'FF-F2F_c40', 'FF-FS_c40', 'FF-NT_c40')
    # Datasets whose JSON index has an extra compression level that the 'c40' aliases may select.
    _FF_FAMILY = ('FF-DF', 'FF-F2F', 'FF-FS', 'FF-NT', 'FaceForensics++', 'DeepFakeDetection', 'FaceShifter')
    # Datasets whose JSON index is keyed by ``self.compression`` when no alias is used.
    _COMPRESSION_KEYED = _FF_FAMILY + (
        'ivy_fake_train', 'ivy_fake_test',
        'ivy_fake_test_Deepfakes', 'ivy_fake_test_NeuralTextures',
        'ivy_fake_test_FaceSwap', 'ivy_fake_test_Face2Face',
    )

    def __init__(self, config=None, mode='train'):
        """Initializes the dataset object.

        Args:
            config (dict): A dictionary containing configuration parameters.
            mode (str): A string indicating the mode (train or test).

        Raises:
            ValueError: If config is None, or if LMDB training is requested for
                several datasets that are not all part of ``FFpp_pool``.
            NotImplementedError: If mode is not train or test.
        """
        if config is None:
            raise ValueError('A configuration dictionary is required.')
        # Validate the mode first so an unsupported mode produces the documented
        # error instead of failing on the config['frame_num'][mode] lookup.
        if mode not in ('train', 'test'):
            raise NotImplementedError('Only train and test modes are supported.')

        self.config = config
        self.mode = mode
        self.compression = config['compression']
        self.frame_num = config['frame_num'][mode]

        self.video_level = config.get('video_mode', False)
        self.clip_size = config.get('clip_size', None)
        self.lmdb = config.get('lmdb', False)

        self.image_list = []
        self.label_list = []

        if mode == 'train':
            dataset_list = config['train_dataset']
            # All training datasets are merged into one pool of samples.
            image_list, label_list = [], []
            for one_data in dataset_list:
                tmp_image, tmp_label, _ = self.collect_img_and_label_for_one_dataset(one_data)
                image_list.extend(tmp_image)
                label_list.extend(tmp_label)
            if self.lmdb:
                if len(dataset_list) > 1:
                    # Only one LMDB environment is opened, so several datasets
                    # are supported only when they share the FaceForensics++ store.
                    if not all_in_pool(dataset_list, FFpp_pool):
                        raise ValueError('Training with multiple dataset and lmdb is not implemented yet.')
                    lmdb_name = 'FaceForensics++_lmdb'
                else:
                    lmdb_name = f"{dataset_list[0] if dataset_list[0] not in FFpp_pool else 'FaceForensics++'}_lmdb"
                self.env = self._open_lmdb(lmdb_name)
        else:
            one_data = config['test_dataset']
            image_list, label_list, _ = self.collect_img_and_label_for_one_dataset(one_data)
            if self.lmdb:
                lmdb_name = f"{one_data}_lmdb" if one_data not in FFpp_pool else 'FaceForensics++_lmdb'
                self.env = self._open_lmdb(lmdb_name)

        assert len(image_list) != 0 and len(label_list) != 0, f"Collect nothing for {mode} mode!"
        self.image_list, self.label_list = image_list, label_list

        # Index-aligned view used by __getitem__.
        self.data_dict = {
            'image': self.image_list,
            'label': self.label_list,
        }

        self.transform = self.init_data_aug_method()

    def _open_lmdb(self, lmdb_name):
        """Open the read-only LMDB environment ``<lmdb_dir>/<lmdb_name>``."""
        lmdb_path = os.path.join(self.config['lmdb_dir'], lmdb_name)
        return lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False)

    @staticmethod
    def _to_lmdb_key(file_path):
        """Convert a frame path into the key used inside the LMDB store."""
        if file_path.startswith('.'):
            file_path = file_path.replace('./datasets\\', '')
        return file_path

    @staticmethod
    def _contains_none(items):
        """Return True if any item is None or is a list that contains None."""
        return any(item is None or (isinstance(item, list) and None in item) for item in items)

    def init_data_aug_method(self):
        """Build the albumentations pipeline applied to training samples.

        Returns:
            An ``A.Compose`` object with a fixed set of photometric/geometric transforms.
        """
        # NOTE(review): `A.Cutout` and the `quality_lower`/`quality_upper` arguments are
        # deprecated in recent albumentations releases - confirm the pinned version.
        trans = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(p=0.3),
            A.ImageCompression(quality_lower=40, quality_upper=100, p=0.1),
            A.GaussNoise(p=0.1),
            A.MotionBlur(p=0.1),
            A.CLAHE(p=0.1),
            A.ChannelShuffle(p=0.1),
            A.Cutout(p=0.1),
            A.RandomGamma(p=0.3),
            A.GlassBlur(p=0.3),
        ])
        return trans

    def rescale_landmarks(self, landmarks, original_size=256, new_size=224):
        """Scale landmark coordinates from ``original_size`` to ``new_size``."""
        scale_factor = new_size / original_size
        rescaled_landmarks = landmarks * scale_factor
        return rescaled_landmarks

    def collect_img_and_label_for_one_dataset(self, dataset_name: str):
        """Collects image and label lists.

        Args:
            dataset_name (str): Name of one dataset, e.g., 'FF-F2F'. A ``_c40`` alias
                (e.g. 'FF-F2F_c40') reads ``<alias>.json`` but indexes it with the base
                name and the 'c40' compression level.

        Returns:
            A sequence of image paths (frame level) or of clips, each a list of
            ``clip_size`` paths (video level).
            A sequence of labels aligned with the first sequence.
            A sequence of unique video names aligned with the first sequence.
            The three sequences are shuffled together; they are empty when the
            requested split holds no usable sample.

        Raises:
            ValueError: If the JSON index cannot be read, a label is missing from
                ``config['label_dict']``, or ``clip_size`` is unset in video mode.
        """
        label_list = []
        frame_path_list = []
        video_name_list = []

        # Fall back to the alternative mount point when the configured folder is absent.
        if not os.path.exists(self.config['dataset_json_folder']):
            self.config['dataset_json_folder'] = self.config['dataset_json_folder'].replace('/Youtu_Pangu_Security_Public', '/Youtu_Pangu_Security/public')
        try:
            with open(os.path.join(self.config['dataset_json_folder'], dataset_name + '.json'), 'r') as f:
                dataset_info = json.load(f)
        except Exception as e:
            print(e)
            raise ValueError(f'dataset {dataset_name} not exist!') from e

        # Resolve '<name>_c40' aliases to the base dataset plus a forced compression level.
        cp = None
        if dataset_name in self._C40_ALIASES:
            dataset_name = dataset_name[:-len(self._C40_SUFFIX)]
            cp = 'c40'

        for split_label in dataset_info[dataset_name]:
            sub_dataset_info = dataset_info[dataset_name][split_label][self.mode]

            # Some datasets nest their videos one level deeper, keyed by compression.
            if cp is None and dataset_name in self._COMPRESSION_KEYED:
                sub_dataset_info = sub_dataset_info[self.compression]
            elif cp == 'c40' and dataset_name in self._FF_FAMILY:
                sub_dataset_info = sub_dataset_info['c40']

            for video_name, video_info in sub_dataset_info.items():
                # The label prefix keeps names unique across real/fake subsets.
                unique_video_name = video_info['label'] + '_' + video_name

                if video_info['label'] not in self.config['label_dict']:
                    raise ValueError(f'Label {video_info["label"]} is not found in the configuration file.')
                label = self.config['label_dict'][video_info['label']]
                frame_paths = video_info['frames']
                if not frame_paths:
                    # Nothing to sample from; indexing frame_paths[0] below would fail.
                    print(f"Skipping video {unique_video_name} because it has no frames.")
                    continue

                # Order frames by the integer encoded in the file name; paths may
                # use either Windows or POSIX separators.
                separator = '\\' if '\\' in frame_paths[0] else '/'
                frame_paths = sorted(frame_paths, key=lambda x: int(x.split(separator)[-1].split('.')[0]))

                # When the video holds more frames than requested, keep only
                # self.frame_num of them. The start/step must be derived from the
                # ORIGINAL frame count, otherwise the first frames are always taken.
                available_frames = len(frame_paths)
                total_frames = available_frames
                if self.frame_num < available_frames:
                    total_frames = self.frame_num
                    if self.video_level:
                        # One contiguous window: random position for training, the head otherwise.
                        start_frame = random.randint(0, available_frames - self.frame_num) if self.mode == 'train' else 0
                        frame_paths = frame_paths[start_frame:start_frame + self.frame_num]
                    else:
                        # Frames spread evenly over the whole video.
                        step = available_frames // self.frame_num
                        frame_paths = frame_paths[0:available_frames:step][:self.frame_num]

                if self.video_level:
                    if self.clip_size is None:
                        raise ValueError('clip_size must be specified when video_level is True.')

                    if total_frames >= self.clip_size:
                        selected_clips = []
                        num_clips = total_frames // self.clip_size

                        if num_clips > 1:
                            # Spacing between clip anchors so clips span the kept frames.
                            clip_step = (total_frames - self.clip_size) // (num_clips - 1)
                            for i in range(num_clips):
                                # Training jitters the start inside the clip's slot; testing is deterministic.
                                start_frame = random.randrange(i * clip_step, min((i + 1) * clip_step, total_frames - self.clip_size + 1)) if self.mode == 'train' else i * clip_step
                                continuous_frames = frame_paths[start_frame:start_frame + self.clip_size]
                                assert len(continuous_frames) == self.clip_size, 'clip_size is not equal to the length of frame_path_list'
                                selected_clips.append(continuous_frames)
                        else:
                            start_frame = random.randrange(0, total_frames - self.clip_size + 1) if self.mode == 'train' else 0
                            continuous_frames = frame_paths[start_frame:start_frame + self.clip_size]
                            assert len(continuous_frames) == self.clip_size, 'clip_size is not equal to the length of frame_path_list'
                            selected_clips.append(continuous_frames)

                        # One label / video name per clip.
                        label_list.extend([label] * len(selected_clips))
                        frame_path_list.extend(selected_clips)
                        video_name_list.extend([unique_video_name] * len(selected_clips))
                    else:
                        print(f"Skipping video {unique_video_name} because it has less than clip_size ({self.clip_size}) frames ({total_frames}).")
                else:
                    # One label / video name per frame.
                    label_list.extend([label] * len(frame_paths))
                    frame_path_list.extend(frame_paths)
                    video_name_list.extend([unique_video_name] * len(frame_paths))

        # Shuffle the three lists together so they stay index-aligned.
        shuffled = list(zip(label_list, frame_path_list, video_name_list))
        if not shuffled:
            # zip(*[]) cannot be unpacked into three names; report "nothing found" instead.
            return (), (), ()
        random.shuffle(shuffled)
        label_list, frame_path_list, video_name_list = zip(*shuffled)

        return frame_path_list, label_list, video_name_list

    def load_rgb(self, file_path):
        """
        Load an RGB image from a file path and resize it to a specified resolution.

        Args:
            file_path: A string indicating the path to the image file
                (or the source of the LMDB key when LMDB is enabled).

        Returns:
            An Image object containing the loaded and resized image.

        Raises:
            ValueError: If the image cannot be found in LMDB or cannot be decoded.
            AssertionError: If the file does not exist on the filesystem.
        """
        size = self.config['resolution']
        if not self.lmdb:
            # Index files may hold Windows-style separators.
            if not os.path.exists(file_path):
                file_path = file_path.replace('\\', '/')
            assert os.path.exists(file_path), f"{file_path} does not exist"
            img = cv2.imread(file_path)
            if img is None:
                raise ValueError('Loaded image is None: {}'.format(file_path))
        else:
            with self.env.begin(write=False) as txn:
                image_bin = txn.get(self._to_lmdb_key(file_path).encode())
                if image_bin is None:
                    # txn.get yields None for an unknown key; fail with a clear message.
                    raise ValueError('Image not found in lmdb: {}'.format(file_path))
                image_buf = np.frombuffer(image_bin, dtype=np.uint8)
                img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError('Loaded image is None: {}'.format(file_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
        return Image.fromarray(np.array(img, dtype=np.uint8))

    def load_mask(self, file_path):
        """
        Load a binary mask image from a file path and resize it to a specified resolution.

        Args:
            file_path: A string indicating the path to the mask file, or None.

        Returns:
            A float32 numpy array with values scaled by 1/255, or an all-zero
            ``(size, size, 1)`` array when the path is None or the file is absent.
        """
        size = self.config['resolution']
        if file_path is None:
            return np.zeros((size, size, 1))
        if not self.lmdb:
            if not os.path.exists(file_path):
                return np.zeros((size, size, 1))
            mask = cv2.imread(file_path, 0)
            if mask is None:
                mask = np.zeros((size, size))
        else:
            # NOTE(review): this branch decodes in colour (3 channels) while the
            # filesystem branch reads grayscale, so the returned shapes differ
            # between the two storage modes - confirm this is intended.
            with self.env.begin(write=False) as txn:
                image_bin = txn.get(self._to_lmdb_key(file_path).encode())
                if image_bin is None:
                    mask = np.zeros((size, size, 3))
                else:
                    image_buf = np.frombuffer(image_bin, dtype=np.uint8)
                    mask = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
        mask = cv2.resize(mask, (size, size)) / 255
        mask = np.expand_dims(mask, axis=2)
        return np.float32(mask)

    def load_landmark(self, file_path):
        """
        Load 2D facial landmarks from a file path.

        Args:
            file_path: A string indicating the path to the landmark file, or None.

        Returns:
            A numpy array of landmarks rescaled from a 256-pixel frame to
            ``config['resolution']``, or an all-zero ``(81, 2)`` array when the
            landmarks are unavailable.
        """
        if file_path is None:
            return np.zeros((81, 2))
        if not self.lmdb:
            if not os.path.exists(file_path):
                return np.zeros((81, 2))
            landmark = np.load(file_path)
        else:
            with self.env.begin(write=False) as txn:
                binary = txn.get(self._to_lmdb_key(file_path).encode())
                if binary is None:
                    # Same fallback as a missing file on the filesystem.
                    return np.zeros((81, 2))
                landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2))
        landmark = self.rescale_landmarks(np.float32(landmark), original_size=256, new_size=self.config['resolution'])
        return landmark

    def to_tensor(self, img):
        """
        Convert an image to a PyTorch tensor.
        """
        return T.ToTensor()(img)

    def normalize(self, img):
        """
        Normalize an image with ``config['mean']`` and ``config['std']``.
        """
        mean = self.config['mean']
        std = self.config['std']
        normalize = T.Normalize(mean=mean, std=std)
        return normalize(img)

    def data_aug(self, img, landmark=None, mask=None, augmentation_seed=None):
        """
        Apply data augmentation to an image, landmark, and mask.

        Args:
            img: The image to be augmented (passed to the transform as ``image``).
            landmark: A numpy array containing the 2D facial landmarks to be augmented.
            mask: A numpy array containing the binary mask to be augmented; its
                third axis is squeezed before use.
            augmentation_seed: Optional seed. When given, the global ``random`` and
                ``numpy.random`` generators are seeded with it before the transform
                and re-seeded from fresh entropy afterwards.

        Returns:
            The augmented image, landmark, and mask.
        """
        # A shared seed makes every frame of a clip receive the same augmentation.
        if augmentation_seed is not None:
            random.seed(augmentation_seed)
            np.random.seed(augmentation_seed)

        kwargs = {'image': img}

        if landmark is not None:
            kwargs['keypoints'] = landmark
            kwargs['keypoint_params'] = A.KeypointParams(format='xy')
        if mask is not None:
            mask = mask.squeeze(2)
            # An all-zero mask is not passed to the transform and is returned as is.
            if mask.max() > 0:
                kwargs['mask'] = mask

        transformed = self.transform(**kwargs)

        # Read the image from the transform OUTPUT; the input dict still holds
        # the un-augmented image.
        augmented_img = transformed['image']
        augmented_landmark = transformed.get('keypoints')
        augmented_mask = transformed.get('mask', mask)

        if augmented_landmark is not None:
            augmented_landmark = np.array(augmented_landmark)

        # Restore non-deterministic global RNG state for subsequent sampling.
        if augmentation_seed is not None:
            random.seed()
            np.random.seed()

        return augmented_img, augmented_landmark, augmented_mask

    def __getitem__(self, index, no_norm=False):
        """
        Returns the data point at the given index.

        Args:
            index (int): The index of the data point.
            no_norm (bool): When True, skip tensor conversion and normalization.

        Returns:
            A tuple containing the image tensor, the label, the landmark tensor,
            and the mask tensor. Landmark/mask entries stay as lists holding None
            when they are disabled in the configuration.

        Raises:
            Exception: Re-raised from ``load_rgb`` when the sample at index 0
                itself cannot be loaded (it is the fallback for every other index).
        """
        image_paths = self.data_dict['image'][index]
        label = self.data_dict['label'][index]

        # Frame-level samples are a single path; treat them as a one-frame clip.
        if not isinstance(image_paths, list):
            image_paths = [image_paths]

        image_tensors = []
        landmark_tensors = []
        mask_tensors = []
        augmentation_seed = None

        for image_path in image_paths:
            # Draw one seed per clip so all of its frames share the augmentation.
            if self.video_level and image_path == image_paths[0]:
                augmentation_seed = random.randint(0, 2**32 - 1)

            # Masks and landmarks mirror the frame directory layout.
            mask_path = image_path.replace('frames', 'masks')
            landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy')

            try:
                image = self.load_rgb(image_path)
            except Exception as e:
                print(f"Error loading image at index {index}: {e}")
                if index == 0:
                    # Index 0 is the fallback sample; retrying it would recurse forever.
                    raise
                return self.__getitem__(0)
            image = np.array(image)

            if self.config['with_mask']:
                mask = self.load_mask(mask_path)
            else:
                mask = None
            if self.config['with_landmark']:
                landmarks = self.load_landmark(landmark_path)
            else:
                landmarks = None

            if self.mode == 'train' and self.config['use_data_augmentation']:
                image_trans, landmarks_trans, mask_trans = self.data_aug(image, landmarks, mask, augmentation_seed)
            else:
                image_trans, landmarks_trans, mask_trans = deepcopy(image), deepcopy(landmarks), deepcopy(mask)

            if not no_norm:
                image_trans = self.normalize(self.to_tensor(image_trans))
                if self.config['with_landmark']:
                    # Use the augmented landmarks so they stay consistent with the image.
                    landmarks_trans = torch.from_numpy(landmarks_trans)
                if self.config['with_mask']:
                    mask_trans = torch.from_numpy(mask_trans)

            image_tensors.append(image_trans)
            landmark_tensors.append(landmarks_trans)
            mask_tensors.append(mask_trans)

        if self.video_level:
            # Stack the frames of the clip along a new leading dimension.
            image_tensors = torch.stack(image_tensors, dim=0)
            if not self._contains_none(landmark_tensors):
                landmark_tensors = torch.stack(landmark_tensors, dim=0)
            if not self._contains_none(mask_tensors):
                mask_tensors = torch.stack(mask_tensors, dim=0)
        else:
            # Frame level: unwrap the single element.
            image_tensors = image_tensors[0]
            if not self._contains_none(landmark_tensors):
                landmark_tensors = landmark_tensors[0]
            if not self._contains_none(mask_tensors):
                mask_tensors = mask_tensors[0]

        return image_tensors, label, landmark_tensors, mask_tensors

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of data points.

        Args:
            batch (list): A list of tuples containing the image tensor, the label,
                the landmark tensor, and the mask tensor.

        Returns:
            A dict with keys 'image', 'label', 'landmark' and 'mask'. 'landmark'
            and 'mask' are None when any sample of the batch lacks them.
        """
        images, labels, landmarks, masks = zip(*batch)

        images = torch.stack(images, dim=0)
        labels = torch.LongTensor(labels)

        contains_none = DeepfakeAbstractBaseDataset._contains_none
        landmarks = None if contains_none(landmarks) else torch.stack(landmarks, dim=0)
        masks = None if contains_none(masks) else torch.stack(masks, dim=0)

        data_dict = {}
        data_dict['image'] = images
        data_dict['label'] = labels
        data_dict['landmark'] = landmarks
        data_dict['mask'] = masks
        return data_dict

    def __len__(self):
        """
        Return the length of the dataset.

        Returns:
            An integer indicating the length of the dataset.

        Raises:
            AssertionError: If the number of images and labels in the dataset are not equal.
        """
        assert len(self.image_list) == len(self.label_list), 'Number of images and labels are not equal'
        return len(self.image_list)
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the training dataset from a detector config and
    # iterate once over the resulting DataLoader.
    with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/video_baseline.yaml', 'r') as f:
        config = yaml.safe_load(f)

    train_set = DeepfakeAbstractBaseDataset(config=config, mode='train')
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config['train_batchSize'],
        shuffle=True,
        num_workers=0,
        collate_fn=train_set.collate_fn,
    )

    from tqdm import tqdm
    for iteration, batch in enumerate(tqdm(train_data_loader)):
        # Batches are only pulled through the pipeline; nothing is done with them.
        ...
|
|
|
|
|
|
|