Spaces:
Runtime error
Runtime error
| ''' | |
| This file contains functions to test and save the results | |
| ''' | |
| from __future__ import print_function | |
| import os | |
| import argparse | |
| import torch.backends.cudnn as cudnn | |
| from ssd import build_ssd | |
| from utils import draw_boxes, helpers, save_boxes | |
| import logging | |
| import time | |
| import datetime | |
| from torch.autograd import Variable | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import Dataset, DataLoader | |
| from data import * | |
| import shutil | |
| import torch.nn as nn | |
| def test_net_batch(args, net, gpu_id, dataset, transform, thresh): | |
| ''' | |
| Batch testing | |
| ''' | |
| num_images = len(dataset) | |
| if args.limit != -1: | |
| num_images = args.limit | |
| data_loader = DataLoader(dataset, args.batch_size, | |
| num_workers=args.num_workers, | |
| shuffle=False, collate_fn=detection_collate, | |
| pin_memory=True) | |
| total = len(dataset) | |
| logging.debug('Test dataset size is {}'.format(total)) | |
| done = 0 | |
| for batch_idx, (images, targets, metadata) in enumerate(data_loader): | |
| done = done + len(images) | |
| logging.debug('processing {}/{}'.format(done, total)) | |
| if args.cuda: | |
| images = images.cuda() | |
| targets = [ann.cuda() for ann in targets] | |
| else: | |
| images = Variable(images) | |
| targets = [Variable(ann, volatile=True) for ann in targets] | |
| # targets = [ann for ann in targets] | |
| y, debug_boxes, debug_scores = net(images) # forward pass | |
| detections = y.data | |
| k = 0 | |
| for img, meta in zip(images, metadata): | |
| img_id = meta[0] | |
| x_l = meta[1] | |
| y_l = meta[2] | |
| img = img.permute(1,2,0) | |
| # scale each detection back up to the image | |
| scale = torch.Tensor([img.shape[1], img.shape[0], | |
| img.shape[1], img.shape[0]]) | |
| recognized_boxes = [] | |
| recognized_scores = [] | |
| # [1,2,200,5] | |
| # we only care about math class | |
| # hence select detections[image_id, class, detection_id, detection_score] | |
| # class=1 for math | |
| i = 1 | |
| j = 0 | |
| while j < detections.size(2) and detections[k, i, j, 0] >= thresh: # TODO it was 0.6 | |
| score = detections[k, i, j, 0] | |
| pt = (detections[k, i, j, 1:] * args.window).cpu().numpy() | |
| coords = (pt[0] + x_l, pt[1] + y_l, pt[2] + x_l, pt[3] + y_l) | |
| #coords = (pt[0], pt[1], pt[2], pt[3]) | |
| recognized_boxes.append(coords) | |
| recognized_scores.append(score.cpu().numpy()) | |
| j += 1 | |
| save_boxes(args, recognized_boxes, recognized_scores, img_id) | |
| k = k + 1 | |
| if args.verbose: | |
| draw_boxes(args, img.cpu().numpy(), recognized_boxes, recognized_scores, | |
| debug_boxes, debug_scores, scale, img_id) | |
| def test_gtdb(args): | |
| gpu_id = 0 | |
| if args.cuda: | |
| gpu_id = helpers.get_freer_gpu() | |
| torch.cuda.set_device(gpu_id) | |
| # load net | |
| num_classes = 2 # +1 background | |
| # initialize SSD | |
| net = build_ssd(args, 'test', exp_cfg[args.cfg], gpu_id, args.model_type, num_classes) | |
| logging.debug(net) | |
| net.to(gpu_id) | |
| net = nn.DataParallel(net) | |
| net.load_state_dict(torch.load(args.trained_model, map_location={'cuda:1':'cuda:0'})) | |
| net.eval() | |
| logging.debug('Finished loading model!') | |
| dataset = GTDBDetection(args, args.test_data, split='test', | |
| transform=BaseTransform(args.model_type, (246,246,246)), | |
| target_transform=GTDBAnnotationTransform()) | |
| if args.cuda: | |
| net = net.to(gpu_id) | |
| cudnn.benchmark = True | |
| # evaluation | |
| test_net_batch(args, net, gpu_id, dataset, | |
| BaseTransform(args.model_type, (246,246,246)), | |
| thresh=args.visual_threshold) | |
| def parse_args(): | |
| parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') | |
| parser.add_argument('--trained_model', default='weights/ssd300_GTDB_990.pth', | |
| type=str, help='Trained state_dict file path to open') | |
| parser.add_argument('--save_folder', default='eval/', type=str, | |
| help='Dir to save results') | |
| parser.add_argument('--visual_threshold', default=0.6, type=float, | |
| help='Final confidence threshold') | |
| parser.add_argument('--cuda', default=False, type=bool, | |
| help='Use cuda to train model') | |
| parser.add_argument('--dataset_root', default=GTDB_ROOT, help='Location of VOC root directory') | |
| parser.add_argument('--test_data', default="testing_data", help='testing data file') | |
| parser.add_argument('--verbose', default=False, type=bool, help='plot output') | |
| parser.add_argument('--suffix', default="_10", type=str, help='suffix of directory of images for testing') | |
| parser.add_argument('--exp_name', default="SSD", help='Name of the experiment. Will be used to generate output') | |
| parser.add_argument('--model_type', default=300, type=int, | |
| help='Type of ssd model, ssd300 or ssd512') | |
| parser.add_argument('--use_char_info', default=False, type=bool, help='Whether or not to use char info') | |
| parser.add_argument('--limit', default=-1, type=int, help='limit on number of test examples') | |
| parser.add_argument('--cfg', default="gtdb", type=str, | |
| help='Type of network: either gtdb or math_gtdb_512') | |
| parser.add_argument('--batch_size', default=16, type=int, | |
| help='Batch size for training') | |
| parser.add_argument('--num_workers', default=4, type=int, | |
| help='Number of workers used in data loading') | |
| parser.add_argument('--kernel', default="3 3", type=int, nargs='+', | |
| help='Kernel size for feature layers: 3 3 or 1 5') | |
| parser.add_argument('--padding', default="1 1", type=int, nargs='+', | |
| help='Padding for feature layers: 1 1 or 0 2') | |
| parser.add_argument('--neg_mining', default=True, type=bool, | |
| help='Whether or not to use hard negative mining with ratio 1:3') | |
| parser.add_argument('--log_dir', default="logs", type=str, | |
| help='dir to save the logs') | |
| parser.add_argument('--stride', default=0.1, type=float, | |
| help='Stride to use for sliding window') | |
| parser.add_argument('--window', default=1200, type=int, | |
| help='Sliding window size') | |
| parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks") | |
| args = parser.parse_args() | |
| if args.cuda and torch.cuda.is_available(): | |
| torch.set_default_tensor_type('torch.cuda.FloatTensor') | |
| else: | |
| torch.set_default_tensor_type('torch.FloatTensor') | |
| if not os.path.exists(args.save_folder): | |
| os.mkdir(args.save_folder) | |
| if os.path.exists(os.path.join(args.save_folder, args.exp_name)): | |
| shutil.rmtree(os.path.join(args.save_folder, args.exp_name)) | |
| return args | |
| if __name__ == '__main__': | |
| args = parse_args() | |
| start = time.time() | |
| try: | |
| filepath=os.path.join(args.log_dir, args.exp_name + "_" + str(round(time.time())) + ".log") | |
| print('Logging to ' + filepath) | |
| logging.basicConfig(filename=filepath, | |
| filemode='w', format='%(process)d - %(asctime)s - %(message)s', | |
| datefmt='%d-%b-%y %H:%M:%S', level=logging.DEBUG) | |
| test_gtdb(args) | |
| except Exception as e: | |
| logging.error("Exception occurred", exc_info=True) | |
| end = time.time() | |
| logging.debug('Toal time taken ' + str(datetime.timedelta(seconds=end-start))) | |
| logging.debug("Testing done!") | |