File size: 3,546 Bytes
0788e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import os
import random
import time

import numpy as np
import torch
from data import create_dataloader
from networks.freqnet import freqnet
from options.test_options import TestOptions
from util import printSet
from validate import validate


def seed_torch(seed=1029):
    """Make runs reproducible by seeding every RNG and pinning cuDNN behaviour.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and CUDA
    (current device and all devices) generators with ``seed``. It also exports
    ``PYTHONHASHSEED``. Because hash randomization is fixed at interpreter
    start-up, that variable only affects subprocesses launched afterwards.

    cuDNN autotuning is turned off and deterministic mode is requested. cuDNN
    is also disabled entirely, which trades speed for reproducibility.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if you are using multi-GPU.
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


seed_torch(100)


# Root directory of each benchmark to evaluate, in evaluation order.
# Disabled benchmark, kept for reference (same flags as the others):
#   'AIGIBench': dataroot '/data/ziqiang/Benchmark'
_TEST_SET_ROOTS = {
    'jpeg': '/data/ziqiang/jpeg',
    'noise': '/data/ziqiang/noise',
    'sample': '/data/ziqiang/sample',
}

# Per-benchmark evaluation settings. Each entry gets its own dict. Because the
# images in these datasets have different shapes, resizing is required during
# batch detection (no_resize=False). no_crop=True is applied uniformly.
DetectionTests = {
    name: {'dataroot': root, 'no_resize': False, 'no_crop': True}
    for name, root in _TEST_SET_ROOTS.items()
}

# Set up logging
logging.basicConfig(
    filename='log_test.log', level=logging.INFO, format='%(asctime)s %(message)s'
)
logger = logging.getLogger()


def _log(msg):
    """Print ``msg`` to stdout and also record it through ``logger``."""
    print(msg)
    logger.info(msg)


opt = TestOptions().parse(print_options=False)
_log(f'Model_path {opt.model_path}')

# Get model. The checkpoint is mapped to CPU first, so loading does not depend
# on which GPU it was saved from. model.cuda() then moves the weights over.
model = freqnet(num_classes=1)
model.load_state_dict(
    torch.load(opt.model_path, map_location='cpu'), strict=True
)
model.cuda()
model.eval()

for testSet, test_cfg in DetectionTests.items():
    dataroot = test_cfg['dataroot']
    printSet(testSet)
    logger.info(testSet)

    accs = []
    aps = []
    r_accs = []
    f_accs = []
    _log(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))

    # Each sub-directory of dataroot is evaluated as a separate subset.
    # os.listdir order is arbitrary, so sort it for reproducible logs. Plain
    # files (e.g. .DS_Store) are skipped because they are not dataset folders.
    subsets = sorted(
        entry
        for entry in os.listdir(dataroot)
        if os.path.isdir(os.path.join(dataroot, entry))
    )

    for v_id, val in enumerate(subsets):
        opt.dataroot = os.path.join(dataroot, val)
        # NOTE(review): an empty `classes` presumably means the subset has no
        # per-class sub-folders. Confirm against create_dataloader.
        opt.classes = ''
        opt.no_resize = test_cfg['no_resize']
        opt.no_crop = test_cfg['no_crop']
        dataloader = create_dataloader(opt)
        acc, ap, r_acc, f_acc = validate(model, dataloader)
        accs.append(acc)
        aps.append(ap)
        r_accs.append(r_acc)
        f_accs.append(f_acc)
        _log(
            '({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
                v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100
            )
        )

    if accs:
        mean_acc = np.mean(accs) * 100
        mean_ap = np.mean(aps) * 100
        mean_r_acc = np.mean(r_accs) * 100
        mean_f_acc = np.mean(f_accs) * 100
        # Use the number of evaluated subsets, not a loop index that may be
        # stale (from a previous test set) or undefined (no subsets at all).
        _log(
            '({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
                len(accs), 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc
            )
        )
    else:
        _log(f'No subsets found in {dataroot}; skipping mean.')
    _log('*' * 25)