"""Evaluate a pretrained detector, with optional multi-GPU (DDP) support.

In DDP mode the rendezvous information (RANK, WORLD_SIZE, LOCAL_RANK) is read
from environment variables, as set by torchrun.
"""
import os
import numpy as np
from os.path import join
import cv2
import random
import datetime
import time
import yaml
import pickle
from tqdm import tqdm
from copy import deepcopy
from PIL import Image as pil_image
from metrics.utils import get_test_metrics
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

from dataset.abstract_dataset import DeepfakeAbstractBaseDataset
from dataset.ff_blend import FFBlendDataset
from dataset.fwa_blend import FWABlendDataset
from dataset.pair_dataset import pairDataset

from trainer.trainer import Trainer
from detectors import DETECTOR
from metrics.base_metrics_class import Recorder, calculate_acc_for_test
import metrics_retrieval.utils
from metrics_retrieval.get_metric_pro4 import *
from metrics_retrieval.get_metric import *
from collections import defaultdict

import argparse
from logger import create_logger

parser = argparse.ArgumentParser(description='Process some paths.')
parser.add_argument('--detector_path', type=str,
                    default='/PATH/TO/resnet34.yaml',
                    help='path to detector YAML file')
parser.add_argument("--test_dataset", nargs="+")
parser.add_argument('--weights_path', type=str, default='')
parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel')
parser.add_argument('--use_latest', action='store_true', help='Use Latest Ckpt')
parser.add_argument('--local_rank', '--local-rank', type=int, default=-1, help='Local rank for DDP')
parser.add_argument('--test_config', type=str, default='test_config_p2.yaml',
                    help='test_config_p2.yaml / test_config_p4.yaml')
args = parser.parse_args()


def init_seed(config, seed=None):
    """Seed the Python, NumPy and torch RNGs and return the seed used.

    When ``seed`` is None it is taken from ``config['manualSeed']``. If that
    is also None, a random seed in [1, 10000] is drawn and written back into
    ``config``, so it appears in the logged configuration.
    """
    if seed is None:
        if config['manualSeed'] is None:
            config['manualSeed'] = random.randint(1, 10000)
        seed = config['manualSeed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if config['cuda']:
        torch.cuda.manual_seed_all(seed)
    return seed


def prepare_testing_data(config, ddp=False):
    """Build one test DataLoader per name in ``config['test_dataset']``.

    Args:
        config: merged configuration dict.
        ddp: when True, each loader uses a DistributedSampler so every rank
            evaluates a disjoint slice of the dataset.

    Returns:
        (test_data_loaders, test_data_dicts): dicts keyed by test-dataset name,
        holding the DataLoader and the dataset's ``data_dict`` respectively.
    """
    def get_test_data_loader(config, test_name):
        # Shallow copy so setting 'test_dataset' does not alter the caller's config.
        config = config.copy()
        config['test_dataset'] = test_name
        test_set = DeepfakeAbstractBaseDataset(config=config, mode='test')
        sampler = DistributedSampler(test_set, shuffle=False) if ddp else None
        test_data_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=config['test_batchSize'],
            # Evaluation order is kept deterministic so results line up with the
            # dataset (test_epoch relies on this to re-order DDP results).
            shuffle=False,
            num_workers=int(config['workers']),
            collate_fn=test_set.collate_fn,
            drop_last=False,
            pin_memory=True,
            sampler=sampler,
        )
        return test_data_loader, test_set.data_dict

    test_data_loaders = {}
    test_data_dicts = {}
    for one_test_name in config['test_dataset']:
        loader, data_dict = get_test_data_loader(config, one_test_name)
        test_data_loaders[one_test_name] = loader
        test_data_dicts[one_test_name] = data_dict
    return test_data_loaders, test_data_dicts


def choose_metric(config):
    """Return ``config['metric_scoring']`` after checking it is a supported metric.

    Raises:
        NotImplementedError: if the metric is not one of eer/auc/acc/ap.
    """
    metric_scoring = config['metric_scoring']
    if metric_scoring not in ['eer', 'auc', 'acc', 'ap']:
        raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
    return metric_scoring


def test_one_dataset(model, data_loader, device, local_rank):
    """Run inference over one DataLoader and collect this process's outputs.

    Args:
        model: detector called as ``model(data_dict, inference=True)``.
        data_loader: yields dicts with 'image', 'label', 'mask' and 'landmark'.
        device: device the batch tensors are moved to.
        local_rank: the progress bar is shown only when this is 0 or -1.

    Returns:
        (predictions, labels, features, img_names): the model's ``'prob'``
        output, the batch labels and the model's ``'feat'`` output, each
        concatenated along dim 0 in loader order. ``img_names`` is a list that
        is currently always empty, because names are not collected yet.
    """
    prediction_lists = []
    feature_lists = []
    label_lists = []
    img_name_lists = []

    is_main = local_rank <= 0
    pbar = tqdm(data_loader, total=len(data_loader), disable=not is_main)
    for data_dict in pbar:
        data, label = data_dict['image'], data_dict['label']
        mask, landmark = data_dict['mask'], data_dict['landmark']
        img_names = []  # TODO: image names are not collected yet

        data_dict['image'], data_dict['label'] = data.to(device), label.to(device)
        if mask is not None:
            data_dict['mask'] = mask.to(device)
        if landmark is not None:
            data_dict['landmark'] = landmark.to(device)

        predictions = inference(model, data_dict)

        # Keep per-batch tensors and concatenate once after the loop.
        label_lists.append(data_dict['label'])
        prediction_lists.append(predictions['prob'])
        feature_lists.append(predictions['feat'])
        img_name_lists.extend(img_names)

    def _cat(chunks):
        # A process that saw no batches returns an empty tensor instead of
        # failing inside torch.cat.
        return torch.cat(chunks, dim=0) if chunks else torch.tensor([], device=device)

    predictions_tensor = _cat(prediction_lists)
    labels_tensor = _cat(label_lists)
    feats_tensor = _cat(feature_lists)
    if is_main:
        print("feats_tensor", feats_tensor.shape)
    return predictions_tensor, labels_tensor, feats_tensor, img_name_lists


def _gather_in_dataset_order(local_tensor, num_samples):
    """All-gather a per-rank result tensor and restore the dataset order.

    DistributedSampler(shuffle=False, drop_last=False) pads the index list by
    repeating its first entries until it divides evenly. It then gives rank r
    the padded indices r, r + W, r + 2W, ... (W = world size). Stacking the
    gathered chunks as (K, W, ...) and flattening therefore yields the padded
    order. Slicing to ``num_samples`` drops the duplicated padding samples, so
    they are not double-counted in the metrics.
    """
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered, local_tensor)
    stacked = torch.stack(gathered, dim=1)  # (K, W, ...)
    merged = stacked.reshape(-1, *local_tensor.shape[1:])  # row i == padded index i
    return merged[:num_samples]


def test_epoch(model, test_data_loaders, test_data_dicts, device, local_rank, ddp, config, logger):
    """Evaluate ``model`` on every test loader.

    On the main process (global rank 0, or the only process without DDP),
    the metrics returned by ``calculate_acc_for_test`` are logged for each
    dataset. Predictions, labels, features, metrics and image names are then
    pickled to ``<dir of logger.log_path>/<dataset name>.pkl``.

    ``test_data_dicts`` is accepted for interface compatibility and is not used.

    Returns:
        On the main process, a dict mapping dataset name to its metrics;
        otherwise None.
    """
    model.eval()
    # Use the global rank so multi-node runs compute and write results once.
    is_main = (not ddp) or dist.get_rank() == 0
    metrics_all_datasets = {}

    for key, data_loader in test_data_loaders.items():
        print("Run Dataset:", key)
        logger.info(f"--------------- Run Dataset: {key} ---------------")
        logger.info(f"--------------- Log path: {logger.log_path} ---------------")

        if ddp and hasattr(data_loader.sampler, 'set_epoch'):
            data_loader.sampler.set_epoch(0)

        predictions_tensor, labels_tensor, feats_tensor, img_names = test_one_dataset(
            model, data_loader, device, local_rank)

        if ddp:
            # Assumes the loader was built by prepare_testing_data
            # (DistributedSampler with shuffle=False), so every rank holds
            # the same number of samples.
            num_samples = len(data_loader.dataset)
            all_predictions = _gather_in_dataset_order(predictions_tensor, num_samples)
            all_labels = _gather_in_dataset_order(labels_tensor, num_samples)
            all_feats = _gather_in_dataset_order(feats_tensor, num_samples)
        else:
            all_predictions = predictions_tensor
            all_labels = labels_tensor
            all_feats = feats_tensor

        if not is_main:
            continue

        # Tensors are passed in both modes, matching what the DDP path always
        # passed. NOTE(review): confirm calculate_acc_for_test accepts tensors.
        metric_one_dataset = calculate_acc_for_test(
            all_labels, all_predictions, config['backbone_config']['num_classes'])
        metrics_all_datasets[key] = metric_one_dataset

        tqdm.write(f"dataset: {key}")
        for k, v in metric_one_dataset.items():
            tqdm.write(f"{k}: {v}")
            logger.info(f"{k}: {v}")

        pkl_save_path = os.path.join(os.path.dirname(logger.log_path), f"{key}.pkl")
        save_data = {
            "all_predictions": all_predictions.cpu().numpy(),
            "all_labels": all_labels.cpu().numpy(),
            "all_feats": all_feats.cpu().numpy(),
            # Metrics are saved too, to make later analysis easier.
            "metrics": metric_one_dataset,
            "all_names": img_names,
        }
        with open(pkl_save_path, "wb") as f:
            pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    return metrics_all_datasets if is_main else None


@torch.no_grad()
def inference(model, data_dict):
    """Run ``model(data_dict, inference=True)`` without gradients under CUDA fp16 autocast."""
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        predictions = model(data_dict, inference=True)
    return predictions


def _load_config():
    """Load the detector YAML and the shared test YAML, then apply CLI overrides.

    Keys from the shared test config override the detector config. The
    detector's ``label_dict``, if present, is kept. ``--test_dataset`` and
    ``--weights_path`` override whatever the YAML files set.
    """
    with open(args.detector_path, 'r') as f:
        config = yaml.safe_load(f)
    with open(f'./training/config/{args.test_config}', 'r') as f:
        config_base = yaml.safe_load(f)
    # The label dictionary is shared by all datasets.
    if 'label_dict' in config:
        config_base['label_dict'] = config['label_dict']
    config.update(config_base)
    if args.test_dataset:
        config['test_dataset'] = args.test_dataset
    if args.weights_path:
        config['weights_path'] = args.weights_path
    return config


def _resolve_ckpt_path(weights_path):
    """Map ``weights_path`` (a .pth file or a run directory) to the checkpoint file to load."""
    if args.use_latest:
        return os.path.join(weights_path, "test/protocol_2_test/ckpt_latest.pth")
    if weights_path.endswith("pth"):
        return weights_path
    return os.path.join(weights_path, "test/protocol_2_test/ckpt_best.pth")


def _load_weights(model, ckpt_path, device, logger):
    """Load a state dict into ``model``, stripping the ``module.`` added by DDP training.

    Loading is non-strict, so any missing or unexpected keys are logged
    instead of being silently ignored.
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    logger.info(f"Load ckpt: {ckpt_path}")
    new_state_dict = {k.replace('module.', ''): v for k, v in ckpt.items()}
    incompatible = model.load_state_dict(new_state_dict, strict=False)
    missing = getattr(incompatible, 'missing_keys', [])
    unexpected = getattr(incompatible, 'unexpected_keys', [])
    if missing or unexpected:
        logger.info(f"Checkpoint key mismatch: {len(missing)} missing, "
                    f"{len(unexpected)} unexpected (first missing: {missing[:5]}, "
                    f"first unexpected: {unexpected[:5]})")


def _run_retrieval_metrics(config, results_dir):
    """Run retrieval evaluation on the per-dataset pkl files written by test_epoch."""
    rank_max = 10
    seed = 42
    for prot in config['test_dataset']:
        # test_epoch saves each dataset's results as "<dataset name>.pkl".
        pkl_file_path = os.path.join(results_dir, f"{prot}.pkl")
        if "protocol_2" in prot or "protocol_3" in prot:
            run_retrieval_evaluation(pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                                     rank_max=rank_max, random_seed=seed)
        elif "protocol_4" in prot:
            run_retrieval_evaluation_p4(pkl_file_path=pkl_file_path, query_mode='10_sample_avg',
                                        rank_max=rank_max, random_seed=seed,
                                        yaml_path="config/test_config_p4.yaml")


def main():
    """Set up optional DDP, load config and weights, run the tests, then run retrieval metrics."""
    ddp = args.ddp
    local_rank = args.local_rank
    if ddp:
        if local_rank < 0:
            # torchrun exports LOCAL_RANK instead of passing --local-rank.
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',  # rendezvous info from environment variables (set by torchrun)
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
            rank=int(os.environ.get("RANK", 0)),
        )
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Without DDP the single process is the main process, even though
    # local_rank keeps its default of -1.
    is_main = (not ddp) or dist.get_rank() == 0

    config = _load_config()
    weights_path = config.get('weights_path')
    if not weights_path:
        raise ValueError("No weights path given: pass --weights_path or set "
                         "'weights_path' in the config.")

    seed = init_seed(config)
    if ddp:
        # Offset the seed per rank so data augmentation differs between processes.
        init_seed(config, seed + dist.get_rank())

    if config['cudnn']:
        cudnn.benchmark = True

    logs_test_dir = weights_path.replace("logs", "logs_test")
    os.makedirs(logs_test_dir, exist_ok=True)
    logger = create_logger(os.path.join(logs_test_dir, 'testing.log'))
    logger.info('Save log to {}'.format(logs_test_dir))

    logger.info("--------------- Configuration ---------------")
    params_string = "Parameters: \n" + "".join(
        f"{key}: {value}\n" for key, value in config.items())
    logger.info(params_string)

    test_data_loaders, test_data_dicts = prepare_testing_data(config, ddp)

    model_class = DETECTOR[config['model_name']]
    model = model_class(config).to(device)
    if is_main:
        for name, param in model.named_parameters():
            print(f"{name}: {param.shape}")

    # For models containing LoRA, switch to eval mode first to avoid
    # repeatedly stacking weights.
    model_name = config['model_name'].lower()
    if 'lora' in model_name or 'pmoe' in model_name:
        model.eval()

    _load_weights(model, _resolve_ckpt_path(weights_path), device, logger)
    if is_main:
        print('===> Load checkpoint done!')

    if ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank)

    test_epoch(model, test_data_loaders, test_data_dicts, device,
               local_rank, ddp, config, logger)
    if is_main:
        print('===> Test Done!')

    if ddp:
        dist.barrier()
        dist.destroy_process_group()

    if is_main:
        # Read the pkl files from the directory test_epoch saved them in.
        _run_retrieval_metrics(config, os.path.dirname(logger.log_path))


if __name__ == '__main__':
    main()

# Notes:
# 1. Useful information in the log
# 2. Create the log_test directory
# 3. Create logger text output
# 4. Save features and labels