| """ |
| Evaluate a pretrained model with multi-GPU support. |
| """ |
| import os |
| import numpy as np |
| from os.path import join |
| import cv2 |
| import random |
| import datetime |
| import time |
| import yaml |
| import pickle |
| from tqdm import tqdm |
| from copy import deepcopy |
| from PIL import Image as pil_image |
| from metrics.utils import get_test_metrics |
| import torch |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.backends.cudnn as cudnn |
| import torch.nn.functional as F |
| import torch.utils.data |
| import torch.optim as optim |
| import torch.distributed as dist |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| from dataset.abstract_dataset import DeepfakeAbstractBaseDataset |
| from dataset.ff_blend import FFBlendDataset |
| from dataset.fwa_blend import FWABlendDataset |
| from dataset.pair_dataset import pairDataset |
|
|
| from trainer.trainer import Trainer |
| from detectors import DETECTOR |
| from metrics.base_metrics_class import Recorder, calculate_acc_for_test |
| import metrics_retrieval.utils |
| from metrics_retrieval.get_metric_pro4 import * |
| from metrics_retrieval.get_metric import * |
|
|
| from collections import defaultdict |
|
|
| import argparse |
| from logger import create_logger |
|
|
# Command-line interface. Options are registered from a table, in the order
# they should appear in `--help`; each entry is (flag spellings, add_argument kwargs).
parser = argparse.ArgumentParser(description='Process some paths.')
_CLI_OPTIONS = (
    (('--detector_path',), {
        'type': str,
        'default': '/PATH/TO/resnet34.yaml',
        'help': 'path to detector YAML file',
    }),
    (('--test_dataset',), {'nargs': "+"}),
    (('--weights_path',), {'type': str, 'default': ''}),
    (('--ddp',), {'action': 'store_true', 'help': 'Use DistributedDataParallel'}),
    (('--use_latest',), {'action': 'store_true', 'help': 'Use Latest Ckpt'}),
    # Both spellings are accepted for the local-rank flag.
    (('--local_rank', '--local-rank'), {
        'type': int,
        'default': -1,
        'help': 'Local rank for DDP',
    }),
    (('--test_config',), {
        'type': str,
        'default': 'test_config_p2.yaml',
        'help': 'test_config_p2.yaml / test_config_p4.yaml',
    }),
)
for _option_flags, _option_kwargs in _CLI_OPTIONS:
    parser.add_argument(*_option_flags, **_option_kwargs)
# Parsed once at import time; main() reads the result from this module-level name.
args = parser.parse_args()
|
|
|
|
def init_seed(config, seed=None):
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        config: Mapping with the keys ``'manualSeed'`` and ``'cuda'``. When
            ``seed`` is not given and ``config['manualSeed']`` is ``None``, a
            random seed in [1, 10000] is drawn and stored back into
            ``config['manualSeed']``.
        seed: Optional explicit seed. When provided it is used as-is and
            ``config['manualSeed']`` is left untouched.

    Returns:
        The integer seed that was applied.
    """
    if seed is None:
        if config['manualSeed'] is None:
            config['manualSeed'] = random.randint(1, 10000)
        seed = config['manualSeed']

    random.seed(seed)
    # NumPy keeps its own global RNG; without this any np.random use stays
    # unseeded. np.random.seed only accepts values in [0, 2**32), hence the modulo.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if config['cuda']:
        torch.cuda.manual_seed_all(seed)
    return seed
|
|
|
|
def prepare_testing_data(config, ddp=False):
    """Build one test DataLoader per dataset named in ``config['test_dataset']``.

    Args:
        config: Configuration mapping. Reads ``'test_dataset'`` (iterable of
            dataset names), ``'test_batchSize'`` and ``'workers'``; the whole
            mapping is also forwarded to ``DeepfakeAbstractBaseDataset``.
        ddp: When true, each loader uses a non-shuffling ``DistributedSampler``
            so every rank evaluates its own shard.

    Returns:
        A pair ``(test_data_loaders, test_data_dicts)``, both keyed by dataset
        name: the DataLoader and the dataset's ``data_dict`` attribute.
    """
    def get_test_data_loader(base_config, test_name):
        # Work on a (shallow) copy so the caller's config keeps its full
        # list of test datasets while this dataset sees a single name.
        dataset_config = base_config.copy()
        dataset_config['test_dataset'] = test_name
        test_set = DeepfakeAbstractBaseDataset(
            config=dataset_config,
            mode='test',
        )

        sampler = DistributedSampler(test_set, shuffle=False) if ddp else None

        test_data_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=dataset_config['test_batchSize'],
            # Evaluation must not shuffle: the order of the saved
            # predictions/features should be the dataset order in both the
            # DDP and the single-process path.
            shuffle=False,
            num_workers=int(dataset_config['workers']),
            collate_fn=test_set.collate_fn,
            drop_last=False,
            pin_memory=True,
            sampler=sampler,
        )
        return test_data_loader, test_set.data_dict

    test_data_loaders = {}
    test_data_dicts = {}
    for one_test_name in config['test_dataset']:
        loader, data_dict = get_test_data_loader(config, one_test_name)
        test_data_loaders[one_test_name] = loader
        test_data_dicts[one_test_name] = data_dict
    return test_data_loaders, test_data_dicts
|
|
|
|
def choose_metric(config):
    """Return ``config['metric_scoring']`` after checking it is supported.

    Raises:
        NotImplementedError: if the configured metric is not one of
            'eer', 'auc', 'acc' or 'ap'.
    """
    chosen = config['metric_scoring']
    if chosen in ['eer', 'auc', 'acc', 'ap']:
        return chosen
    raise NotImplementedError('metric {} is not implemented'.format(chosen))
|
|
|
|
def test_one_dataset(model, data_loader, device, local_rank):
    """Run inference over one data loader and collect the outputs.

    Args:
        model: Detector passed through to ``inference``.
        data_loader: Iterable of batch dicts with the keys 'image', 'label',
            'mask' and 'landmark' ('mask'/'landmark' may be ``None``).
        device: Device the batch tensors are moved to.
        local_rank: The progress bar is only shown when this equals 0.

    Returns:
        ``(predictions, labels, feats, img_names)``: the first three are the
        per-batch 'prob' outputs, labels and 'feat' outputs concatenated along
        dim 0 (an empty tensor on ``device`` when the loader yields nothing).
        ``img_names`` is always an empty list: the loop never collects names.
    """
    def _concat(chunks):
        # Empty loaders still have to yield a tensor on the right device.
        if chunks:
            return torch.cat(chunks, dim=0)
        return torch.tensor([], device=device)

    probs, feats, targets = [], [], []
    img_name_lists = []

    progress = tqdm(enumerate(data_loader), total=len(data_loader), disable=(local_rank != 0))
    for _, batch in progress:
        image, target = batch['image'], batch['label']
        optional = (('mask', batch['mask']), ('landmark', batch['landmark']))

        # Move the batch onto the evaluation device in place.
        batch['image'], batch['label'] = image.to(device), target.to(device)
        for field, value in optional:
            if value is not None:
                batch[field] = value.to(device)

        outputs = inference(model, batch)

        targets.append(batch['label'])
        probs.append(outputs['prob'])
        feats.append(outputs['feat'])
        # NOTE(review): no per-batch image names are gathered here, so
        # img_name_lists stays empty — confirm whether the batch carries names.

    predictions_tensor = _concat(probs)
    labels_tensor = _concat(targets)
    feats_tensor = _concat(feats)

    print("feats_tensor", feats_tensor.shape)

    return predictions_tensor, labels_tensor, feats_tensor, img_name_lists
|
|
|
|
def _to_numpy(values):
    """Return ``values`` as a NumPy array, moving torch tensors to host memory first."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _gather_across_ranks(local_tensor, sampler, num_samples):
    """all_gather ``local_tensor`` from every rank and merge the shards along dim 0.

    With a non-shuffling ``DistributedSampler`` rank ``r`` evaluates the indices
    ``r, r + world_size, ...`` and the sampler pads the index list so all ranks
    get the same count. Interleaving the shards restores dataset order, which
    puts the padded duplicates at the tail where they can be cut off.
    """
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered, local_tensor)

    is_ordered_shard = (
        isinstance(sampler, DistributedSampler)
        and not sampler.shuffle
        and sampler.num_replicas == world_size
    )
    if is_ordered_shard:
        interleaved = torch.stack(gathered, dim=1)
        merged = interleaved.reshape(-1, *local_tensor.shape[1:])
        return merged[:num_samples]
    # Unknown sharding scheme: fall back to plain rank-by-rank concatenation.
    return torch.cat(gathered, dim=0)


def test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger):
    """Evaluate ``model`` on every test loader, log metrics and save results.

    For each dataset the predictions, labels and features are collected with
    ``test_one_dataset`` (gathered from all ranks under DDP). The reporting
    process then computes metrics with ``calculate_acc_for_test``, logs them
    and pickles everything to ``<dir of logger.log_path>/<dataset>.pkl``.

    Args:
        model: Detector to evaluate; switched to eval mode here.
        test_data_loaders: Mapping of dataset name to DataLoader.
        test_data_dicts: Mapping of dataset name to dataset ``data_dict``.
            Kept for interface compatibility; not read by this function.
        device: Device used for inference.
        local_rank: Local rank of this process.
        ddp: Whether a ``torch.distributed`` process group is active.
        config: Reads ``config['backbone_config']['num_classes']``.
        logger: Logger exposing ``info`` and a ``log_path`` attribute.

    Returns:
        Mapping of dataset name to its metrics on the reporting process,
        ``None`` on every other process.
    """
    model.eval()

    # A non-distributed run has exactly one process, so it must report even
    # when local_rank was left at its CLI default of -1.
    is_main_process = (not ddp) or local_rank == 0

    metrics_all_datasets = {}

    for key, data_loader in test_data_loaders.items():
        print("Run Dataset:", key)

        logger.info(f"--------------- Run Dataset: {key} ---------------")
        logger.info(f"--------------- Run Dataset: {logger.log_path} ---------------")

        if ddp and hasattr(data_loader.sampler, 'set_epoch'):
            data_loader.sampler.set_epoch(0)

        predictions_tensor, labels_tensor, feats_tensor, img_names = test_one_dataset(
            model, data_loader, device, local_rank)

        if ddp:
            sampler = data_loader.sampler
            num_samples = len(data_loader.dataset)
            all_predictions = _gather_across_ranks(predictions_tensor, sampler, num_samples)
            all_labels = _gather_across_ranks(labels_tensor, sampler, num_samples)
            all_feats = _gather_across_ranks(feats_tensor, sampler, num_samples)
        else:
            all_predictions = predictions_tensor.cpu().numpy()
            all_labels = labels_tensor.cpu().numpy()
            all_feats = feats_tensor.cpu().numpy()

        if not is_main_process:
            continue

        # NOTE(review): the metric helper receives torch tensors under DDP and
        # NumPy arrays otherwise (unchanged from before) — confirm it accepts both.
        metric_one_dataset = calculate_acc_for_test(
            all_labels, all_predictions, config['backbone_config']['num_classes'])
        metrics_all_datasets[key] = metric_one_dataset

        tqdm.write(f"dataset: {key}")
        for k, v in metric_one_dataset.items():
            tqdm.write(f"{k}: {v}")
            logger.info(f"{k}: {v}")

        # Results may be tensors (DDP) or arrays (single process); normalise
        # before pickling so the file never depends on torch/CUDA to load.
        pkl_save_path = os.path.join(os.path.dirname(logger.log_path), f"{key}.pkl")
        save_data = {
            "all_predictions": _to_numpy(all_predictions),
            "all_labels": _to_numpy(all_labels),
            "all_feats": _to_numpy(all_feats),
            "metrics": metric_one_dataset,
            "all_names": img_names,
        }
        with open(pkl_save_path, "wb") as f:
            pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    return metrics_all_datasets if is_main_process else None
|
|
|
|
@torch.no_grad()
def inference(model, data_dict):
    """Run a gradient-free forward pass under CUDA float16 autocast.

    Args:
        model: Callable invoked as ``model(data_dict, inference=True)``.
        data_dict: Batch passed to the model unchanged.

    Returns:
        Whatever the model returns.
    """
    # torch.autocast(device_type="cuda") is the non-deprecated spelling of
    # torch.cuda.amp.autocast and needs no per-call import.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        predictions = model(data_dict, inference=True)
    return predictions
|
|
|
|
def _resolve_local_rank(ddp, cli_local_rank):
    """Return the local rank to use for this process.

    Without DDP there is a single process, which is treated as rank 0 so the
    "main process" checks hold. With DDP an explicit non-negative
    ``--local_rank`` wins; otherwise the ``LOCAL_RANK`` environment variable is
    used (the launcher may provide the rank that way instead of as a flag).
    """
    if not ddp:
        return 0
    if cli_local_rank is not None and cli_local_rank >= 0:
        return cli_local_rank
    return int(os.environ.get("LOCAL_RANK", 0))


def _resolve_ckpt_path(weights_path, use_latest):
    """Map ``--weights_path`` (a .pth file or a run directory) to a checkpoint file."""
    if use_latest:
        return os.path.join(weights_path, "test/protocol_2_test/ckpt_latest.pth")
    if weights_path.endswith("pth"):
        return weights_path
    return os.path.join(weights_path, "test/protocol_2_test/ckpt_best.pth")


def _run_retrieval_evaluations(test_datasets, results_dir):
    """Run the retrieval evaluation for every protocol 2/3/4 test dataset.

    Reads ``<results_dir>/<dataset>.pkl``, i.e. the file test_epoch writes for
    that dataset, instead of a machine-specific absolute location.
    """
    rank_max = 10
    random_seed = 42
    for prot in test_datasets:
        pkl_file_path = os.path.join(results_dir, f"{prot}.pkl")
        if "protocol_2" in prot or "protocol_3" in prot:
            run_retrieval_evaluation(
                pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                rank_max=rank_max, random_seed=random_seed)
        elif "protocol_4" in prot:
            run_retrieval_evaluation_p4(
                pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                rank_max=rank_max, random_seed=random_seed,
                yaml_path="config/test_config_p4.yaml")


def main():
    """Evaluate a checkpoint on the configured test datasets.

    Steps: set up the (optionally distributed) device, merge the detector YAML
    with the test YAML and CLI overrides, seed the RNGs, build the loaders and
    the model, load the checkpoint, run ``test_epoch`` and finally, on the
    main process, run the retrieval evaluation on the saved pickles.

    Raises:
        ValueError: if ``--weights_path`` was not provided; the log directory
            and the checkpoint location are both derived from it.
    """
    ddp = args.ddp
    weights_path = args.weights_path
    if not weights_path:
        raise ValueError(
            "--weights_path is required: the checkpoint and the test log "
            "directory are derived from it")

    local_rank = _resolve_local_rank(ddp, args.local_rank)
    is_main_process = local_rank == 0

    if ddp:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
            rank=int(os.environ.get("RANK", 0)),
        )
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(args.detector_path, 'r') as f:
        config = yaml.safe_load(f)
    with open(f'./training/config/{args.test_config}', 'r') as f:
        config_base = yaml.safe_load(f)

    # The test config overrides the detector config, except for label_dict,
    # where the detector's own mapping is kept.
    if 'label_dict' in config:
        config_base['label_dict'] = config['label_dict']
    config.update(config_base)

    if args.test_dataset:
        config['test_dataset'] = args.test_dataset
    config['weights_path'] = weights_path

    seed = init_seed(config)
    if ddp:
        # Re-seed with a rank-dependent offset so ranks do not share RNG streams.
        seed += dist.get_rank()
        init_seed(config, seed)

    if config['cudnn']:
        cudnn.benchmark = True

    logs_test_dir = weights_path.replace("logs", "logs_test")
    os.makedirs(logs_test_dir, exist_ok=True)
    logger = create_logger(os.path.join(logs_test_dir, 'testing.log'))
    logger.info('Save log to {}'.format(logs_test_dir))

    logger.info("--------------- Configuration ---------------")
    params_string = "Parameters: \n" + "".join(
        "{}: {}".format(key, value) + "\n" for key, value in config.items())
    logger.info(params_string)

    test_data_loaders, test_data_dicts = prepare_testing_data(config, ddp)

    model_class = DETECTOR[config['model_name']]
    model = model_class(config).to(device)

    if is_main_process:
        for name, param in model.named_parameters():
            print(f"{name}: {param.shape}")

    # LoRA / PMoE detectors are put in eval mode before the weights are loaded.
    model_name = config['model_name'].lower()
    if 'lora' in model_name or "pmoe" in model_name:
        model.eval()

    ckpt_path = _resolve_ckpt_path(weights_path, args.use_latest)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    logger.info(f"Load ckpt: {ckpt_path}")

    # Drop the 'module.' prefix added by DistributedDataParallel.
    # NOTE(review): this removes every occurrence of 'module.', not only a
    # leading one — confirm no parameter name contains it elsewhere.
    new_state_dict = {k.replace('module.', ''): v for k, v in ckpt.items()}

    # strict=False would otherwise hide a checkpoint/model mismatch entirely.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.info(
            f"State dict mismatch: {len(load_result.missing_keys)} missing keys "
            f"{load_result.missing_keys}, {len(load_result.unexpected_keys)} "
            f"unexpected keys {load_result.unexpected_keys}")

    if is_main_process:
        print('===> Load checkpoint done!')

    if ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank)

    test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger)

    if is_main_process:
        print('===> Test Done!')

    if ddp:
        dist.barrier()
        dist.destroy_process_group()

    if is_main_process:
        # test_epoch writes its pickles next to the log file.
        results_dir = os.path.dirname(logger.log_path)
        _run_retrieval_evaluations(config['test_dataset'], results_dir)
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()
|
|
|
|
| |
|
|
| |
|
|
| |
|
|
| |
|
|
|
|
|
|