import sys import time import os import csv import torch import logging from util import Logger, printSet from validate import validate from networks.resnet import resnet50 from options.test_options import TestOptions import networks.resnet as resnet from data import create_dataloader import numpy as np import random def seed_torch(seed=1029): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.backends.cudnn.enabled = False seed_torch(100) DetectionTests = { 'ForenSynths': { 'dataroot': '/home/HDD/yjz/dataset/ForenSynths/test', 'no_resize': False, # Due to the different shapes of images in the dataset, resizing is required during batch detection. 'no_crop': True, }, 'UniversalFakeDetect': { 'dataroot': '/home/HDD/yjz/dataset/UniversalFakeDetect', 'no_resize': False, # Due to the different shapes of images in the dataset, resizing is required during batch detection. 'no_crop': True, }, 'Genimage': { 'dataroot': '/home/HDD/yjz/dataset/Genimage', 'no_resize': False, # Due to the different shapes of images in the dataset, resizing is required during batch detection. 'no_crop': True, }, 'AIGIBench': { 'dataroot': '/home/HDD/yjz/dataset/AIGIBench', 'no_resize': False, # Due to the different shapes of images in the dataset, resizing is required during batch detection. 'no_crop': True, }, } # Set up logging logging.basicConfig( filename='log_test.log', level=logging.INFO, format='%(asctime)s %(message)s' ) logger = logging.getLogger() opt = TestOptions().parse(print_options=False) log_msg = f'Model_path {opt.model_path}' print(log_msg) logger.info(log_msg) # Get model model = resnet50(num_classes=1) model.load_state_dict(torch.load(opt.model_path), strict=True) model.cuda() model.eval() for testSet in DetectionTests.keys(): dataroot = DetectionTests[testSet]['dataroot'] printSet(testSet) logger.info(testSet) accs = [] aps = [] r_accs = [] f_accs = [] current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) print(current_time) logger.info(current_time) for v_id, val in enumerate(os.listdir(dataroot)): opt.dataroot = '{}/{}'.format(dataroot, val) opt.classes = '' # os.listdir(opt.dataroot) if multiclass[v_id] else [''] opt.no_resize = DetectionTests[testSet]['no_resize'] opt.no_crop = DetectionTests[testSet]['no_crop'] dataloader = create_dataloader(opt) validate(model, dataloader) acc, ap, r_acc, f_acc = validate(model, dataloader) accs.append(acc) aps.append(ap) r_accs.append(r_acc) f_accs.append(f_acc) log_msg = ( '({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format( v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100 ) ) print(log_msg) logger.info(log_msg) mean_acc = np.array(accs).mean() * 100 mean_ap = np.array(aps).mean() * 100 mean_r_acc = np.array(r_accs).mean() * 100 mean_f_acc = np.array(f_accs).mean() * 100 log_msg = '({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f} f_acc: {:.1f}'.format( v_id + 1, 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc ) print(log_msg) logger.info(log_msg) print('*' * 25) logger.info('*' * 25)