File size: 3,653 Bytes
0788e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import sys
import time
import os
import csv
import torch
import logging
from util import Logger, printSet
from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
import networks.resnet as resnet
from data import create_dataloader
import numpy as np
import random

def seed_torch(seed=1029):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value passed to every seeding call below; its string form is
            also stored in the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator: Python, NumPy, torch CPU, the current
    # CUDA device, and finally all CUDA devices (multi-GPU case).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: no autotuning, deterministic kernels only, and cuDNN itself off.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Fix the seed once at start-up so repeated evaluation runs are comparable.
seed_torch(100)

def _lgrad_test_set(dataroot):
    """Return the config dict for one test set rooted at ``dataroot``."""
    return {
        'dataroot': dataroot,
        # Due to the different shapes of images in the dataset, resizing is
        # required during batch detection.
        'no_resize': False,
        'no_crop': True,
    }


# Test sets to evaluate, keyed by display name. Each value holds the dataset
# root and the resize/crop switches that the evaluation loop copies onto `opt`.
DetectionTests = {
    # Disabled: 'AIGIBench': _lgrad_test_set('/data/ziqiang/Benchmark-LGrad'),
    'jpeg': _lgrad_test_set('/data/ziqiang/jpeg-LGrad'),
    'noise': _lgrad_test_set('/data/ziqiang/noise-LGrad'),
    'sample': _lgrad_test_set('/data/ziqiang/sample-LGrad'),
}

# Set up logging: basicConfig attaches a file handler for log_test.log to the
# root logger, which is the logger fetched on the next line.
logging.basicConfig(filename='log_test.log', level=logging.INFO, format='%(asctime)s %(message)s')
logger = logging.getLogger()

# Parse the test options and record which checkpoint is being evaluated,
# both on stdout and in the log file.
opt = TestOptions().parse(print_options=False)
log_msg = f'Model_path {opt.model_path}'
print(log_msg)
logger.info(log_msg)

# Get model
model = resnet50(num_classes=1)
# Deserialize the checkpoint onto the CPU rather than a hard-coded GPU index:
# mapping to 'cuda:2' fails on hosts with fewer than three visible GPUs and
# stages the weights on a device other than the one model.cuda() selects.
state_dict = torch.load(opt.model_path, map_location='cpu')
# strict=True: checkpoint keys must match the model's keys exactly.
model.load_state_dict(state_dict, strict=True)
model.cuda()
model.eval()

# Evaluate the model on each configured test set. Every entry found under a
# test set's dataroot is treated as one sub-dataset: it is validated on its
# own, its metrics are printed and logged, and a mean row over all
# sub-datasets closes the section.
for testSet, test_cfg in DetectionTests.items():
    dataroot = test_cfg['dataroot']
    printSet(testSet)
    logger.info(testSet)

    accs = []
    aps = []
    r_accs = []
    f_accs = []
    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(current_time)
    logger.info(current_time)

    # os.listdir returns entries in arbitrary order; sort them so the row
    # order and indices of the report are reproducible across runs/machines.
    for v_id, val in enumerate(sorted(os.listdir(dataroot))):
        opt.dataroot = f'{dataroot}/{val}'
        opt.classes = ''  # os.listdir(opt.dataroot) if multiclass[v_id] else ['']
        opt.no_resize = test_cfg['no_resize']
        opt.no_crop = test_cfg['no_crop']
        dataloader = create_dataloader(opt)
        acc, ap, r_acc, f_acc = validate(model, dataloader)
        accs.append(acc)
        aps.append(ap)
        r_accs.append(r_acc)
        f_accs.append(f_acc)
        log_msg = "({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}" \
            .format(v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100)
        print(log_msg)
        logger.info(log_msg)

    if accs:
        mean_acc = np.array(accs).mean() * 100
        mean_ap = np.array(aps).mean() * 100
        mean_r_acc = np.array(r_accs).mean() * 100
        mean_f_acc = np.array(f_accs).mean() * 100
        # len(accs) equals the old `v_id + 1` whenever the loop ran, without
        # relying on a loop variable that is undefined (or stale from the
        # previous test set) when the directory is empty.
        log_msg = "({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}" \
            .format(len(accs), 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc)
    else:
        # Nothing was evaluated: report that instead of a NaN mean row.
        log_msg = f'No sub-datasets found under {dataroot}; mean not computed'
    print(log_msg)
    logger.info(log_msg)
    print('*' * 25)
    logger.info('*' * 25)