# Uploaded by TheKernel01; synced from GitHub via hub-sync (commit 0788e19, verified).
import sys
import time
import os
import csv
import torch
import logging
from util import Logger, printSet
from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
import networks.resnet as resnet
from data import create_dataloader
import numpy as np
import random
def seed_torch(seed=1029):
    """Seed every random generator used here and force reproducible cuDNN settings.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, the
    current CUDA device and all CUDA devices). It also exports
    ``PYTHONHASHSEED`` and disables cuDNN altogether, with benchmarking off and
    the deterministic flag on.

    Args:
        seed: integer seed applied to all generators.
    """
    # Setting this at runtime only affects interpreters started afterwards
    # (e.g. spawned subprocesses), not the hashing of the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # every visible GPU (multi-GPU runs)
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


# Fix all seeds before data loaders and the model are created below.
seed_torch(100)
# Preprocessing flags shared by every benchmark. Images within each dataset
# have different shapes, so they must be resized to be batched together;
# cropping stays disabled.
_BATCH_SETTINGS = {'no_resize': False, 'no_crop': True}

# Benchmark name -> dataset root plus preprocessing flags. The evaluation loop
# treats each sub-directory of 'dataroot' as a separate test subset.
DetectionTests = {
    'ForenSynths': {'dataroot': '/home/HDD/yjz/dataset/ForenSynths/test', **_BATCH_SETTINGS},
    'UniversalFakeDetect': {'dataroot': '/home/HDD/yjz/dataset/UniversalFakeDetect', **_BATCH_SETTINGS},
    'Genimage': {'dataroot': '/home/HDD/yjz/dataset/Genimage', **_BATCH_SETTINGS},
    'AIGIBench': {'dataroot': '/home/HDD/yjz/dataset/AIGIBench', **_BATCH_SETTINGS},
}
# Set up logging: every result line below is printed and also appended to
# log_test.log through the root logger.
logging.basicConfig(
    filename='log_test.log', level=logging.INFO, format='%(asctime)s %(message)s'
)
logger = logging.getLogger()

# NOTE(review): presumably parses command-line test options; only
# opt.model_path is read here, and the loop below overwrites several fields.
opt = TestOptions().parse(print_options=False)
log_msg = f'Model_path {opt.model_path}'
print(log_msg)
logger.info(log_msg)

# Get model: ResNet-50 with a single output unit.
model = resnet50(num_classes=1)
# Deserialize onto the CPU first, so a checkpoint saved on a GPU index that
# does not exist on this machine still loads; .cuda() moves it afterwards.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load(opt.model_path, map_location='cpu')
model.load_state_dict(state_dict, strict=True)
model.cuda()
model.eval()
# Evaluate the model on every benchmark. Each sub-directory of a benchmark's
# dataroot is scored separately, then the unweighted mean across subsets is
# reported. All metrics are scaled by 100 for display.
for testSet, test_cfg in DetectionTests.items():
    dataroot = test_cfg['dataroot']
    printSet(testSet)
    logger.info(testSet)
    accs = []
    aps = []
    r_accs = []
    f_accs = []
    current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(current_time)
    logger.info(current_time)

    # os.listdir order is arbitrary; sort so subset indices and log order are
    # reproducible. Skip stray files (e.g. READMEs) that are not subsets.
    subsets = sorted(
        name
        for name in os.listdir(dataroot)
        if os.path.isdir(os.path.join(dataroot, name))
    )
    for v_id, val in enumerate(subsets):
        opt.dataroot = os.path.join(dataroot, val)
        # NOTE(review): '' presumably means "no per-class sub-folders" for
        # create_dataloader; confirm there.
        opt.classes = ''
        opt.no_resize = test_cfg['no_resize']
        opt.no_crop = test_cfg['no_crop']
        dataloader = create_dataloader(opt)
        # Run validation exactly once per subset and keep its metrics.
        acc, ap, r_acc, f_acc = validate(model, dataloader)
        accs.append(acc)
        aps.append(ap)
        r_accs.append(r_acc)
        f_accs.append(f_acc)
        log_msg = '({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
            v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100
        )
        print(log_msg)
        logger.info(log_msg)

    if accs:
        mean_acc = np.mean(accs) * 100
        mean_ap = np.mean(aps) * 100
        mean_r_acc = np.mean(r_accs) * 100
        mean_f_acc = np.mean(f_accs) * 100
        # len(accs) is the number of subsets evaluated for *this* benchmark.
        log_msg = '({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
            len(accs), 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc
        )
    else:
        # Without any subset the means would be NaN and there is no index to report.
        log_msg = 'No subsets found under {}'.format(dataroot)
    print(log_msg)
    logger.info(log_msg)
    print('*' * 25)
    logger.info('*' * 25)