File size: 3,820 Bytes
0788e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
import time
import os
import csv
import torch
import logging
from util import Logger, printSet
from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
import networks.resnet as resnet
from data import create_dataloader
import numpy as np
import random


def seed_torch(seed=1029):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (the CPU generator
    and the CUDA generators), exports ``PYTHONHASHSEED``, and then sets the
    cuDNN backend flags: autotuning (``benchmark``) off, ``deterministic``
    on, and cuDNN itself disabled.

    Args:
        seed: Value passed to every seeding call. Defaults to 1029.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NumPy first, then torch on CPU, the current CUDA device, and finally
    # all CUDA devices (needed for multi-GPU runs).
    remaining_seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in remaining_seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


# Seed all generators once, before any model or data loader is created.
seed_torch(100)


# Benchmarks to evaluate, keyed by benchmark name. Every benchmark uses the
# same preprocessing flags:
#   no_resize=False -> images are resized; due to the different shapes of
#                      images in the datasets, resizing is required during
#                      batch detection.
#   no_crop=True    -> cropping is turned off.
DetectionTests = {
    benchmark: {'dataroot': root, 'no_resize': False, 'no_crop': True}
    for benchmark, root in (
        ('ForenSynths', '/home/HDD/yjz/dataset/ForenSynths/test'),
        ('UniversalFakeDetect', '/home/HDD/yjz/dataset/UniversalFakeDetect'),
        ('Genimage', '/home/HDD/yjz/dataset/Genimage'),
        ('AIGIBench', '/home/HDD/yjz/dataset/AIGIBench'),
    )
}

# Set up logging: timestamped INFO-level records go to log_test.log through
# the root logger.
logging.basicConfig(
    filename='log_test.log', level=logging.INFO, format='%(asctime)s %(message)s'
)
logger = logging.getLogger()

# Parse the test options and record which checkpoint is being evaluated.
opt = TestOptions().parse(print_options=False)
log_msg = f'Model_path {opt.model_path}'
print(log_msg)
logger.info(log_msg)

# Get model
model = resnet50(num_classes=1)
# Load the checkpoint onto the CPU first. Without map_location, torch.load
# restores each tensor to the device it was saved from, which fails when that
# GPU index does not exist here and briefly doubles GPU memory otherwise.
# The weights are moved to the GPU by model.cuda() below.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(opt.model_path, map_location='cpu'), strict=True)
model.cuda()
model.eval()

# Evaluate the model on every benchmark in DetectionTests. Each entry found
# directly under a benchmark's dataroot is treated as one test subset; the
# per-subset metrics and their mean are printed and written to the log.
for testSet, test_cfg in DetectionTests.items():
    dataroot = test_cfg['dataroot']
    printSet(testSet)
    logger.info(testSet)

    accs = []
    aps = []
    r_accs = []
    f_accs = []
    current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(current_time)
    logger.info(current_time)

    # os.listdir returns entries in arbitrary order; sorting keeps the subset
    # indices and the log layout identical across runs and machines.
    for v_id, val in enumerate(sorted(os.listdir(dataroot))):
        opt.dataroot = '{}/{}'.format(dataroot, val)
        # NOTE(review): an empty string presumably means "no per-class
        # sub-folders" -- confirm against create_dataloader.
        opt.classes = ''
        opt.no_resize = test_cfg['no_resize']
        opt.no_crop = test_cfg['no_crop']
        dataloader = create_dataloader(opt)
        # Run the evaluation exactly once per subset. A second call whose
        # result was thrown away used to double the evaluation time.
        acc, ap, r_acc, f_acc = validate(model, dataloader)
        accs.append(acc)
        aps.append(ap)
        r_accs.append(r_acc)
        f_accs.append(f_acc)
        log_msg = (
            '({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
                v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100
            )
        )
        print(log_msg)
        logger.info(log_msg)

    if accs:
        mean_acc = np.array(accs).mean() * 100
        mean_ap = np.array(aps).mean() * 100
        mean_r_acc = np.array(r_accs).mean() * 100
        mean_f_acc = np.array(f_accs).mean() * 100

        # len(accs) equals the former `v_id + 1` but does not depend on the
        # loop variable, which is undefined (or stale from the previous
        # benchmark) when the directory is empty.
        log_msg = '({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f} f_acc: {:.1f}'.format(
            len(accs), 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc
        )
        print(log_msg)
        logger.info(log_msg)
    else:
        # Nothing was evaluated: report it instead of averaging an empty list.
        log_msg = 'No test subsets found in {}'.format(dataroot)
        print(log_msg)
        logger.warning(log_msg)
    print('*' * 25)
    logger.info('*' * 25)