import logging
import os
import time
import numpy as np
import torch
from eval_config import *
from networks.resnet import resnet50
from options.test_options import TestOptions
from validate import validate
# Route all log records to log_test.log, each prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    filename='log_test.log',
)
logger = logging.getLogger()

# Running tests
opt = TestOptions().parse(print_options=False)

# NOTE(review): `model_path` is not defined in this script; presumably it is
# provided by the `from eval_config import *` star import -- confirm.
# The display name is the checkpoint's file name with '.pth' removed.
model_name = os.path.basename(model_path).replace('.pth', '')
header_line = '{} model testing on...'.format(model_name)
rows = [
    [header_line],
    ['testset', 'accuracy', 'avg precision'],
]

# Announce which checkpoint is being evaluated (console and log file).
log_msg = f'Model_path {opt.model_path}'
print(log_msg)
logger.info(log_msg)
print(header_line)

# Per-test-set metrics, appended to by the evaluation loop.
accs, aps, r_accs, f_accs = [], [], [], []

# Record the local wall-clock time at which the run started.
now = time.localtime()
current_time = time.strftime('%Y-%m-%d %H:%M:%S', now)
print(current_time)
logger.info(current_time)
# Build the network and load the checkpoint once. Neither step depends on the
# test set, so repeating them for every entry of `vals` only re-reads the
# checkpoint from disk and re-uploads the weights to the GPU.
# NOTE(review): this assumes validate() does not modify the model -- confirm.
model = resnet50(num_classes=1)
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
model.cuda()
model.eval()

# Evaluate the model on every test set and collect its metrics.
for v_id, val in enumerate(vals):
    opt.dataroot = '{}/{}'.format(dataroot, val)
    # When multiclass[v_id] is truthy, each entry of the test-set directory is
    # used as a class name; otherwise a single empty class name is used.
    opt.classes = os.listdir(opt.dataroot) if multiclass[v_id] else ['']
    opt.no_resize = True  # testing without resizing by default

    # validate() returns six values; the last two are not used here.
    acc, ap, r_acc, f_acc, _, _ = validate(model, opt)
    accs.append(acc)
    aps.append(ap)
    r_accs.append(r_acc)
    f_accs.append(f_acc)

    # Metrics are scaled by 100 for display as percentages.
    log_msg = '({} {:12}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
        v_id, val, acc * 100, ap * 100, r_acc * 100, f_acc * 100
    )
    print(log_msg)
    logger.info(log_msg)
# Report the mean of each metric over all evaluated test sets.
# The count is taken from the collected results rather than from the loop
# variable `v_id`, which is never bound when there are no test sets (the
# original then failed with a NameError after computing NaN means).
num_evaluated = len(accs)
if num_evaluated:
    mean_acc = np.array(accs).mean() * 100
    mean_ap = np.array(aps).mean() * 100
    mean_r_acc = np.array(r_accs).mean() * 100
    mean_f_acc = np.array(f_accs).mean() * 100
    log_msg = '({} {:10}) acc: {:.1f}; ap: {:.1f}; r_acc: {:.1f}; f_acc: {:.1f}'.format(
        num_evaluated, 'Mean', mean_acc, mean_ap, mean_r_acc, mean_f_acc
    )
else:
    log_msg = 'No test sets were evaluated; mean metrics are unavailable.'
print(log_msg)
logger.info(log_msg)

# Visual separator marking the end of this run in the console and the log.
print('*' * 25)
logger.info('*' * 25)