math2tex / ScanSSD /test.py
duycse1603's picture
[Add] source
6163604
'''
Evaluate a trained ScanSSD model on the GTDB test split and save the detected
math bounding boxes (optionally plotting them when --verbose is set).
'''
from __future__ import print_function

import argparse
import datetime
import logging
import os
import shutil
import sys
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms

from ssd import build_ssd
from utils import draw_boxes, helpers, save_boxes
from data import *
def _collect_math_detections(dets, thresh, window, x_offset, y_offset):
    '''Return (boxes, scores) for the math-class detections of one image.

    dets: the per-image slice of the network output, indexed as
        dets[class, detection_id, field]; field 0 is the score and fields 1:
        are box coordinates that are scaled by ``window`` (presumably
        normalised to the sliding window -- confirm against the Detect layer).
    thresh: minimum score to keep. Scanning stops at the first detection
        below it, which assumes detections are sorted by descending score.
    window: sliding-window size in pixels (``args.window``).
    x_offset, y_offset: added to x / y coordinates respectively (presumably
        the window's position on the page).

    Each box is an (x1, y1, x2, y2) tuple; each score is the numpy value of
    the detection's score.
    '''
    boxes, scores = [], []
    math_dets = dets[1]  # class index 1 is the math class
    for row in math_dets:
        score = row[0]
        # Written as `not >=` so a NaN score also ends the scan, as before.
        if not score >= thresh:
            break
        pt = (row[1:] * window).cpu().numpy()
        boxes.append((pt[0] + x_offset, pt[1] + y_offset,
                      pt[2] + x_offset, pt[3] + y_offset))
        scores.append(score.cpu().numpy())
    return boxes, scores


def test_net_batch(args, net, gpu_id, dataset, transform, thresh):
    '''Run ``net`` over ``dataset`` in batches and save the math boxes found.

    For every item, the math detections scoring at least ``thresh`` are
    passed to ``save_boxes``; with ``args.verbose`` they are also plotted via
    ``draw_boxes``.

    args: parsed options; reads limit, batch_size, num_workers, cuda, window
        and verbose.
    net: detector returning (detections, debug_boxes, debug_scores).
    gpu_id, transform: unused; kept for interface compatibility.
    dataset: items collated by ``detection_collate`` into
        (images, targets, metadata); metadata[i] is (img_id, x_offset,
        y_offset, ...).
    thresh: minimum detection score to keep.
    '''
    # Honour --limit: previously computed but never applied.
    if 0 <= args.limit < len(dataset):
        dataset = Subset(dataset, range(args.limit))

    data_loader = DataLoader(dataset, args.batch_size,
                             num_workers=args.num_workers,
                             shuffle=False, collate_fn=detection_collate,
                             pin_memory=True)
    total = len(dataset)
    logging.debug('Test dataset size is {}'.format(total))

    done = 0
    for images, _targets, metadata in data_loader:
        done += len(images)
        logging.debug('processing {}/{}'.format(done, total))
        if args.cuda:
            images = images.cuda()

        # Inference only: skip building an autograd graph. This replaces the
        # deprecated Variable(..., volatile=True) idiom, which is a no-op.
        with torch.no_grad():
            y, debug_boxes, debug_scores = net(images)
        # One device->host copy per batch instead of one per detection.
        detections = y.detach().cpu()

        for k, (img, meta) in enumerate(zip(images, metadata)):
            img_id, x_l, y_l = meta[0], meta[1], meta[2]
            recognized_boxes, recognized_scores = _collect_math_detections(
                detections[k], thresh, args.window, x_l, y_l)
            save_boxes(args, recognized_boxes, recognized_scores, img_id)

            if args.verbose:
                # CHW -> HWC for plotting; scale is (w, h, w, h) of the input.
                img = img.permute(1, 2, 0)
                scale = torch.Tensor([img.shape[1], img.shape[0],
                                      img.shape[1], img.shape[0]])
                draw_boxes(args, img.cpu().numpy(), recognized_boxes,
                           recognized_scores, debug_boxes, debug_scores,
                           scale, img_id)
def test_gtdb(args):
    '''Build the SSD detector, load ``args.trained_model`` and evaluate it.

    The model runs on the least-loaded GPU (``helpers.get_freer_gpu``) when
    ``args.cuda`` is set, otherwise on the CPU. It is evaluated on the
    ``split='test'`` GTDBDetection dataset read from ``args.test_data``.
    '''
    gpu_id = 0
    if args.cuda:
        gpu_id = helpers.get_freer_gpu()
        torch.cuda.set_device(gpu_id)
        device = torch.device('cuda', gpu_id)
    else:
        # Previously net.to(0) forced cuda:0 even without --cuda, so CPU-only
        # runs crashed.
        device = torch.device('cpu')

    num_classes = 2  # math + background
    net = build_ssd(args, 'test', exp_cfg[args.cfg], gpu_id, args.model_type,
                    num_classes)
    logging.debug(net)
    net.to(device)

    if args.cuda:
        # DataParallel requires the module to sit on device_ids[0]. Putting
        # the chosen GPU first keeps the default (all GPUs) when gpu_id == 0
        # and stops a non-zero gpu_id from tripping that check.
        device_ids = [gpu_id] + [d for d in range(torch.cuda.device_count())
                                 if d != gpu_id]
        net = nn.DataParallel(net, device_ids=device_ids)
    else:
        net = nn.DataParallel(net)

    # Weights are loaded through the DataParallel wrapper (presumably the
    # checkpoint keys carry its 'module.' prefix). Mapping every tensor to
    # the run device replaces the old {'cuda:1': 'cuda:0'} map, which failed
    # on CPU and for checkpoints saved on other GPUs.
    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    if not args.cuda:
        # On CPU, bypass DataParallel: if GPUs are visible it would move or
        # scatter the model onto them.
        net = net.module.to(device)
    net.eval()
    logging.debug('Finished loading model!')

    transform = BaseTransform(args.model_type, (246, 246, 246))
    dataset = GTDBDetection(args, args.test_data, split='test',
                            transform=transform,
                            target_transform=GTDBAnnotationTransform())
    if args.cuda:
        cudnn.benchmark = True

    # evaluation
    test_net_batch(args, net, gpu_id, dataset, transform,
                   thresh=args.visual_threshold)
def _str2bool(value):
    '''Parse a boolean command-line value.

    ``type=bool`` is wrong for argparse: bool('False') is True, so e.g.
    ``--cuda False`` used to enable CUDA.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    '''Parse command-line options and prepare the output folder.

    argv: optional argument list; None (the default) parses sys.argv[1:].

    Side effects: sets torch's default tensor type (CUDA only when --cuda is
    given and a GPU is available), creates --save_folder if missing, and
    deletes any existing <save_folder>/<exp_name> directory.
    '''
    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection')
    parser.add_argument('--trained_model', default='weights/ssd300_GTDB_990.pth',
                        type=str, help='Trained state_dict file path to open')
    parser.add_argument('--save_folder', default='eval/', type=str,
                        help='Dir to save results')
    parser.add_argument('--visual_threshold', default=0.6, type=float,
                        help='Final confidence threshold')
    parser.add_argument('--cuda', default=False, type=_str2bool,
                        help='Use cuda to train model')
    parser.add_argument('--dataset_root', default=GTDB_ROOT, help='Location of VOC root directory')
    parser.add_argument('--test_data', default="testing_data", help='testing data file')
    parser.add_argument('--verbose', default=False, type=_str2bool, help='plot output')
    parser.add_argument('--suffix', default="_10", type=str, help='suffix of directory of images for testing')
    parser.add_argument('--exp_name', default="SSD", help='Name of the experiment. Will be used to generate output')
    parser.add_argument('--model_type', default=300, type=int,
                        help='Type of ssd model, ssd300 or ssd512')
    parser.add_argument('--use_char_info', default=False, type=_str2bool, help='Whether or not to use char info')
    parser.add_argument('--limit', default=-1, type=int, help='limit on number of test examples')
    parser.add_argument('--cfg', default="gtdb", type=str,
                        help='Type of network: either gtdb or math_gtdb_512')
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size for training')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='Number of workers used in data loading')
    # List defaults: argparse converts a *string* default with type=int,
    # and int("3 3") fails when the flag is omitted.
    parser.add_argument('--kernel', default=[3, 3], type=int, nargs='+',
                        help='Kernel size for feature layers: 3 3 or 1 5')
    parser.add_argument('--padding', default=[1, 1], type=int, nargs='+',
                        help='Padding for feature layers: 1 1 or 0 2')
    parser.add_argument('--neg_mining', default=True, type=_str2bool,
                        help='Whether or not to use hard negative mining with ratio 1:3')
    parser.add_argument('--log_dir', default="logs", type=str,
                        help='dir to save the logs')
    parser.add_argument('--stride', default=0.1, type=float,
                        help='Stride to use for sliding window')
    parser.add_argument('--window', default=1200, type=int,
                        help='Sliding window size')
    parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
    args = parser.parse_args(argv)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # makedirs also handles nested paths, where os.mkdir failed.
    if not os.path.exists(args.save_folder):
        os.makedirs(args.save_folder)
    exp_dir = os.path.join(args.save_folder, args.exp_name)
    if os.path.exists(exp_dir):
        # Start every run from a clean results directory.
        shutil.rmtree(exp_dir)
    return args
if __name__ == '__main__':
    args = parse_args()
    start = time.time()
    failed = False
    try:
        # logging.basicConfig(filename=...) does not create missing
        # directories; without this every later log line went to stderr.
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        filepath = os.path.join(args.log_dir,
                                args.exp_name + "_" + str(round(time.time())) + ".log")
        print('Logging to ' + filepath)
        logging.basicConfig(filename=filepath,
                            filemode='w', format='%(process)d - %(asctime)s - %(message)s',
                            datefmt='%d-%b-%y %H:%M:%S', level=logging.DEBUG)
        test_gtdb(args)
    except Exception:
        # Log the traceback, and remember the failure so the process exits
        # non-zero instead of reporting success.
        logging.exception("Exception occurred")
        failed = True
    end = time.time()
    logging.debug('Total time taken ' + str(datetime.timedelta(seconds=end - start)))
    logging.debug("Testing done!")
    if failed:
        sys.exit(1)