Spaces:
Running
Running
| """ | |
| Video Face Manipulation Detection Through Ensemble of CNNs | |
| Image and Sound Processing Lab - Politecnico di Milano | |
| Nicolò Bonettini | |
| Edoardo Daniele Cannas | |
| Sara Mandelli | |
| Luca Bondi | |
| Paolo Bestagini | |
| """ | |
| import argparse | |
| import os | |
| import shutil | |
| import warnings | |
| import albumentations as A | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.multiprocessing | |
| from torchvision.transforms import ToPILImage, ToTensor | |
| from isplutils import utils, split | |
| torch.multiprocessing.set_sharing_strategy('file_system') | |
| import torch.nn as nn | |
| from albumentations.pytorch import ToTensorV2 | |
| from sklearn.metrics import roc_auc_score | |
| from tensorboardX import SummaryWriter | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from PIL import ImageChops, Image | |
| from architectures import fornet | |
| from isplutils.data import FrameFaceIterableDataset, load_face | |
| def main(): | |
| # Args | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--net', type=str, help='Net model class', required=True) | |
| parser.add_argument('--traindb', type=str, help='Training datasets', nargs='+', choices=split.available_datasets, | |
| required=True) | |
| parser.add_argument('--valdb', type=str, help='Validation datasets', nargs='+', choices=split.available_datasets, | |
| required=True) | |
| parser.add_argument('--dfdc_faces_df_path', type=str, action='store', | |
| help='Path to the Pandas Dataframe obtained from extract_faces.py on the DFDC dataset. ' | |
| 'Required for training/validating on the DFDC dataset.') | |
| parser.add_argument('--dfdc_faces_dir', type=str, action='store', | |
| help='Path to the directory containing the faces extracted from the DFDC dataset. ' | |
| 'Required for training/validating on the DFDC dataset.') | |
| parser.add_argument('--ffpp_faces_df_path', type=str, action='store', | |
| help='Path to the Pandas Dataframe obtained from extract_faces.py on the FF++ dataset. ' | |
| 'Required for training/validating on the FF++ dataset.') | |
| parser.add_argument('--ffpp_faces_dir', type=str, action='store', | |
| help='Path to the directory containing the faces extracted from the FF++ dataset. ' | |
| 'Required for training/validating on the FF++ dataset.') | |
| parser.add_argument('--face', type=str, help='Face crop or scale', required=True, | |
| choices=['scale', 'tight']) | |
| parser.add_argument('--size', type=int, help='Train patch size', required=True) | |
| parser.add_argument('--batch', type=int, help='Batch size to fit in GPU memory', default=32) | |
| parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate') | |
| parser.add_argument('--valint', type=int, help='Validation interval (iterations)', default=500) | |
| parser.add_argument('--patience', type=int, help='Patience before dropping the LR [validation intervals]', | |
| default=10) | |
| parser.add_argument('--maxiter', type=int, help='Maximum number of iterations', default=20000) | |
| parser.add_argument('--init', type=str, help='Weight initialization file') | |
| parser.add_argument('--scratch', action='store_true', help='Train from scratch') | |
| parser.add_argument('--trainsamples', type=int, help='Limit the number of train samples per epoch', default=-1) | |
| parser.add_argument('--valsamples', type=int, help='Limit the number of validation samples per epoch', | |
| default=6000) | |
| parser.add_argument('--logint', type=int, help='Training log interval (iterations)', default=100) | |
| parser.add_argument('--workers', type=int, help='Num workers for data loaders', default=6) | |
| parser.add_argument('--device', type=int, help='GPU device id', default=0) | |
| parser.add_argument('--seed', type=int, help='Random seed', default=0) | |
| parser.add_argument('--debug', action='store_true', help='Activate debug') | |
| parser.add_argument('--suffix', type=str, help='Suffix to default tag') | |
| parser.add_argument('--attention', action='store_true', | |
| help='Enable Tensorboard log of attention masks') | |
| parser.add_argument('--log_dir', type=str, help='Directory for saving the training logs', | |
| default='runs/binclass/') | |
| parser.add_argument('--models_dir', type=str, help='Directory for saving the models weights', | |
| default='weights/binclass/') | |
| args = parser.parse_args() | |
| # Parse arguments | |
| net_class = getattr(fornet, args.net) | |
| train_datasets = args.traindb | |
| val_datasets = args.valdb | |
| dfdc_df_path = args.dfdc_faces_df_path | |
| ffpp_df_path = args.ffpp_faces_df_path | |
| dfdc_faces_dir = args.dfdc_faces_dir | |
| ffpp_faces_dir = args.ffpp_faces_dir | |
| face_policy = args.face | |
| face_size = args.size | |
| batch_size = args.batch | |
| initial_lr = args.lr | |
| validation_interval = args.valint | |
| patience = args.patience | |
| max_num_iterations = args.maxiter | |
| initial_model = args.init | |
| train_from_scratch = args.scratch | |
| max_train_samples = args.trainsamples | |
| max_val_samples = args.valsamples | |
| log_interval = args.logint | |
| num_workers = args.workers | |
| device = torch.device('cuda:{:d}'.format(args.device)) if torch.cuda.is_available() else torch.device('cpu') | |
| seed = args.seed | |
| debug = args.debug | |
| suffix = args.suffix | |
| enable_attention = args.attention | |
| weights_folder = args.models_dir | |
| logs_folder = args.log_dir | |
| # Random initialization | |
| np.random.seed(seed) | |
| torch.random.manual_seed(seed) | |
| # Load net | |
| net: nn.Module = net_class().to(device) | |
| # Loss and optimizers | |
| criterion = nn.BCEWithLogitsLoss() | |
| min_lr = initial_lr * 1e-5 | |
| optimizer = optim.Adam(net.get_trainable_parameters(), lr=initial_lr) | |
| lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( | |
| optimizer=optimizer, | |
| mode='min', | |
| factor=0.1, | |
| patience=patience, | |
| cooldown=2 * patience, | |
| min_lr=min_lr, | |
| ) | |
| tag = utils.make_train_tag(net_class=net_class, | |
| traindb=train_datasets, | |
| face_policy=face_policy, | |
| patch_size=face_size, | |
| seed=seed, | |
| suffix=suffix, | |
| debug=debug, | |
| ) | |
| # Model checkpoint paths | |
| bestval_path = os.path.join(weights_folder, tag, 'bestval.pth') | |
| last_path = os.path.join(weights_folder, tag, 'last.pth') | |
| periodic_path = os.path.join(weights_folder, tag, 'it{:06d}.pth') | |
| os.makedirs(os.path.join(weights_folder, tag), exist_ok=True) | |
| # Load model | |
| val_loss = min_val_loss = 10 | |
| epoch = iteration = 0 | |
| net_state = None | |
| opt_state = None | |
| if initial_model is not None: | |
| # If given load initial model | |
| print('Loading model form: {}'.format(initial_model)) | |
| state = torch.load(initial_model, map_location='cpu') | |
| net_state = state['net'] | |
| elif not train_from_scratch and os.path.exists(last_path): | |
| print('Loading model form: {}'.format(last_path)) | |
| state = torch.load(last_path, map_location='cpu') | |
| net_state = state['net'] | |
| opt_state = state['opt'] | |
| iteration = state['iteration'] + 1 | |
| epoch = state['epoch'] | |
| if not train_from_scratch and os.path.exists(bestval_path): | |
| state = torch.load(bestval_path, map_location='cpu') | |
| min_val_loss = state['val_loss'] | |
| if net_state is not None: | |
| incomp_keys = net.load_state_dict(net_state, strict=False) | |
| print(incomp_keys) | |
| if opt_state is not None: | |
| for param_group in opt_state['param_groups']: | |
| param_group['lr'] = initial_lr | |
| optimizer.load_state_dict(opt_state) | |
| # Initialize Tensorboard | |
| logdir = os.path.join(logs_folder, tag) | |
| if iteration == 0: | |
| # If training from scratch or initialization remove history if exists | |
| shutil.rmtree(logdir, ignore_errors=True) | |
| # TensorboardX instance | |
| tb = SummaryWriter(logdir=logdir) | |
| if iteration == 0: | |
| dummy = torch.randn((1, 3, face_size, face_size), device=device) | |
| dummy = dummy.to(device) | |
| with warnings.catch_warnings(): | |
| warnings.simplefilter("ignore") | |
| tb.add_graph(net, [dummy, ], verbose=False) | |
| transformer = utils.get_transformer(face_policy=face_policy, patch_size=face_size, | |
| net_normalizer=net.get_normalizer(), train=True) | |
| # Datasets and data loaders | |
| print('Loading data') | |
| # Check if paths for DFDC and FF++ extracted faces and DataFrames are provided | |
| for dataset in train_datasets: | |
| if dataset.split('-')[0] == 'dfdc' and (dfdc_df_path is None or dfdc_faces_dir is None): | |
| raise RuntimeError('Specify DataFrame and directory for DFDC faces for training!') | |
| elif dataset.split('-')[0] == 'ff' and (ffpp_df_path is None or ffpp_faces_dir is None): | |
| raise RuntimeError('Specify DataFrame and directory for FF++ faces for training!') | |
| for dataset in val_datasets: | |
| if dataset.split('-')[0] == 'dfdc' and (dfdc_df_path is None or dfdc_faces_dir is None): | |
| raise RuntimeError('Specify DataFrame and directory for DFDC faces for validation!') | |
| elif dataset.split('-')[0] == 'ff' and (ffpp_df_path is None or ffpp_faces_dir is None): | |
| raise RuntimeError('Specify DataFrame and directory for FF++ faces for validation!') | |
| # Load splits with the make_splits function | |
| splits = split.make_splits(dfdc_df=dfdc_df_path, ffpp_df=ffpp_df_path, dfdc_dir=dfdc_faces_dir, ffpp_dir=ffpp_faces_dir, | |
| dbs={'train': train_datasets, 'val': val_datasets}) | |
| train_dfs = [splits['train'][db][0] for db in splits['train']] | |
| train_roots = [splits['train'][db][1] for db in splits['train']] | |
| val_roots = [splits['val'][db][1] for db in splits['val']] | |
| val_dfs = [splits['val'][db][0] for db in splits['val']] | |
| train_dataset = FrameFaceIterableDataset(roots=train_roots, | |
| dfs=train_dfs, | |
| scale=face_policy, | |
| num_samples=max_train_samples, | |
| transformer=transformer, | |
| size=face_size, | |
| ) | |
| val_dataset = FrameFaceIterableDataset(roots=val_roots, | |
| dfs=val_dfs, | |
| scale=face_policy, | |
| num_samples=max_val_samples, | |
| transformer=transformer, | |
| size=face_size, | |
| ) | |
| train_loader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, ) | |
| val_loader = DataLoader(val_dataset, num_workers=num_workers, batch_size=batch_size, ) | |
| print('Training samples: {}'.format(len(train_dataset))) | |
| print('Validation samples: {}'.format(len(val_dataset))) | |
| if len(train_dataset) == 0: | |
| print('No training samples. Halt.') | |
| return | |
| if len(val_dataset) == 0: | |
| print('No validation samples. Halt.') | |
| return | |
| stop = False | |
| while not stop: | |
| # Training | |
| optimizer.zero_grad() | |
| train_loss = train_num = 0 | |
| train_pred_list = [] | |
| train_labels_list = [] | |
| for train_batch in tqdm(train_loader, desc='Epoch {:03d}'.format(epoch), leave=False, | |
| total=len(train_loader) // train_loader.batch_size): | |
| net.train() | |
| batch_data, batch_labels = train_batch | |
| train_batch_num = len(batch_labels) | |
| train_num += train_batch_num | |
| train_labels_list.append(batch_labels.numpy().flatten()) | |
| train_batch_loss, train_batch_pred = batch_forward(net, device, criterion, batch_data, batch_labels) | |
| train_pred_list.append(train_batch_pred.flatten()) | |
| if torch.isnan(train_batch_loss): | |
| raise ValueError('NaN loss') | |
| train_loss += train_batch_loss.item() * train_batch_num | |
| # Optimization | |
| train_batch_loss.backward() | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| # Logging | |
| if iteration > 0 and (iteration % log_interval == 0): | |
| train_loss /= train_num | |
| tb.add_scalar('train/loss', train_loss, iteration) | |
| tb.add_scalar('lr', optimizer.param_groups[0]['lr'], iteration) | |
| tb.add_scalar('epoch', epoch, iteration) | |
| # Checkpoint | |
| save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch, last_path) | |
| train_loss = train_num = 0 | |
| # Validation | |
| if iteration > 0 and (iteration % validation_interval == 0): | |
| # Model checkpoint | |
| save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch, | |
| periodic_path.format(iteration)) | |
| # Train cumulative stats | |
| train_labels = np.concatenate(train_labels_list) | |
| train_pred = np.concatenate(train_pred_list) | |
| train_labels_list = [] | |
| train_pred_list = [] | |
| train_roc_auc = roc_auc_score(train_labels, train_pred) | |
| tb.add_scalar('train/roc_auc', train_roc_auc, iteration) | |
| tb.add_pr_curve('train/pr', train_labels, train_pred, iteration) | |
| # Validation | |
| val_loss = validation_routine(net, device, val_loader, criterion, tb, iteration, 'val') | |
| tb.flush() | |
| # LR Scheduler | |
| lr_scheduler.step(val_loss) | |
| # Model checkpoint | |
| if val_loss < min_val_loss: | |
| min_val_loss = val_loss | |
| save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch, bestval_path) | |
| # Attention | |
| if enable_attention and hasattr(net, 'get_attention'): | |
| net.eval() | |
| # For each dataframe show the attention for a real,fake couple of frames | |
| for df, root, sample_idx, tag in [ | |
| (train_dfs[0], train_roots[0], train_dfs[0][train_dfs[0]['label'] == False].index[0], | |
| 'train/att/real'), | |
| (train_dfs[0], train_roots[0], train_dfs[0][train_dfs[0]['label'] == True].index[0], | |
| 'train/att/fake'), | |
| ]: | |
| record = df.loc[sample_idx] | |
| tb_attention(tb, tag, iteration, net, device, face_size, face_policy, | |
| transformer, root, record) | |
| if optimizer.param_groups[0]['lr'] == min_lr: | |
| print('Reached minimum learning rate. Stopping.') | |
| stop = True | |
| break | |
| iteration += 1 | |
| if iteration > max_num_iterations: | |
| print('Maximum number of iterations reached') | |
| stop = True | |
| break | |
| # End of iteration | |
| epoch += 1 | |
| # Needed to flush out last events | |
| tb.close() | |
| print('Completed') | |
| def tb_attention(tb: SummaryWriter, | |
| tag: str, | |
| iteration: int, | |
| net: nn.Module, | |
| device: torch.device, | |
| patch_size_load: int, | |
| face_crop_scale: str, | |
| val_transformer: A.BasicTransform, | |
| root: str, | |
| record: pd.Series, | |
| ): | |
| # Crop face | |
| sample_t = load_face(record=record, root=root, size=patch_size_load, scale=face_crop_scale, | |
| transformer=val_transformer) | |
| sample_t_clean = load_face(record=record, root=root, size=patch_size_load, scale=face_crop_scale, | |
| transformer=ToTensorV2()) | |
| if torch.cuda.is_available(): | |
| sample_t = sample_t.cuda(device) | |
| # Transform | |
| # Feed to net | |
| with torch.no_grad(): | |
| att: torch.Tensor = net.get_attention(sample_t.unsqueeze(0))[0].cpu() | |
| att_img: Image.Image = ToPILImage()(att) | |
| sample_img = ToPILImage()(sample_t_clean) | |
| att_img = att_img.resize(sample_img.size, resample=Image.NEAREST).convert('RGB') | |
| sample_att_img = ImageChops.multiply(sample_img, att_img) | |
| sample_att = ToTensor()(sample_att_img) | |
| tb.add_image(tag=tag, img_tensor=sample_att, global_step=iteration) | |
| def batch_forward(net: nn.Module, device: torch.device, criterion, data: torch.Tensor, labels: torch.Tensor) -> ( | |
| torch.Tensor, float, int): | |
| data = data.to(device) | |
| labels = labels.to(device) | |
| out = net(data) | |
| pred = torch.sigmoid(out).detach().cpu().numpy() | |
| loss = criterion(out, labels) | |
| return loss, pred | |
| def validation_routine(net, device, val_loader, criterion, tb, iteration, tag: str, loader_len_norm: int = None): | |
| net.eval() | |
| loader_len_norm = loader_len_norm if loader_len_norm is not None else val_loader.batch_size | |
| val_num = 0 | |
| val_loss = 0. | |
| pred_list = list() | |
| labels_list = list() | |
| for val_data in tqdm(val_loader, desc='Validation', leave=False, total=len(val_loader) // loader_len_norm): | |
| batch_data, batch_labels = val_data | |
| val_batch_num = len(batch_labels) | |
| labels_list.append(batch_labels.flatten()) | |
| with torch.no_grad(): | |
| val_batch_loss, val_batch_pred = batch_forward(net, device, criterion, batch_data, | |
| batch_labels) | |
| pred_list.append(val_batch_pred.flatten()) | |
| val_num += val_batch_num | |
| val_loss += val_batch_loss.item() * val_batch_num | |
| # Logging | |
| val_loss /= val_num | |
| tb.add_scalar('{}/loss'.format(tag), val_loss, iteration) | |
| if isinstance(criterion, nn.BCEWithLogitsLoss): | |
| val_labels = np.concatenate(labels_list) | |
| val_pred = np.concatenate(pred_list) | |
| val_roc_auc = roc_auc_score(val_labels, val_pred) | |
| tb.add_scalar('{}/roc_auc'.format(tag), val_roc_auc, iteration) | |
| tb.add_pr_curve('{}/pr'.format(tag), val_labels, val_pred, iteration) | |
| return val_loss | |
| def save_model(net: nn.Module, optimizer: optim.Optimizer, | |
| train_loss: float, val_loss: float, | |
| iteration: int, batch_size: int, epoch: int, | |
| path: str): | |
| path = str(path) | |
| state = dict(net=net.state_dict(), | |
| opt=optimizer.state_dict(), | |
| train_loss=train_loss, | |
| val_loss=val_loss, | |
| iteration=iteration, | |
| batch_size=batch_size, | |
| epoch=epoch) | |
| torch.save(state, path) | |
| if __name__ == '__main__': | |
| main() | |