| import utils |
| import logging |
| import argparse |
| import importlib |
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| from mmcv import Config, DictAction |
| from mmcv.parallel import MMDataParallel |
| from mmcv.runner import load_checkpoint |
| from mmdet.apis import set_random_seed |
| from mmdet3d.datasets import build_dataset, build_dataloader |
| from mmdet3d.models import build_model |
| from models.utils import DUMP, VERSION |
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    parser.add_argument('--override', nargs='+', action=DictAction)
    # Explicit types are required here: without them a value supplied on the
    # command line stays a str, which breaks the tensor comparison against
    # the threshold and the range() calls over frames / views.
    parser.add_argument('--score_threshold', type=float, default=0.3)
    # stage_id is only interpolated into dump file names, so it is left
    # untyped to keep accepting exactly what it accepted before.
    parser.add_argument('--stage_id', default=5)
    parser.add_argument('--num_frames', type=int, default=3)
    parser.add_argument('--num_views', type=int, default=6)
    return parser.parse_args()


def _scatter_sample_points(ax, xyz, mask, query_id):
    """Scatter the valid sampling points of one query onto one image axis.

    Args:
        ax: Matplotlib axis that already shows the cropped camera image.
        xyz: NumPy array indexed as ``[point, 3]``. Columns 0 and 1 are scaled
            by 1600 and 640, the extent of the cropped image; they are
            presumably normalized image coordinates (TODO confirm). Column 2
            drives the marker size and is presumably depth (TODO confirm).
        mask: NumPy array with one entry per point; entries that round to a
            non-zero value mark the points that are drawn.
        query_id: Integer query index, used only to pick one of five colors.
    """
    keep = np.round(mask).astype(bool)

    cx = xyz[:, 0] * 1600
    cy = xyz[:, 1] * 640

    # Work on a copy so the array backing the loaded tensor is not modified.
    cz = xyz[:, 2].copy()
    # Non-positive values are pushed far away so their marker becomes tiny.
    cz[cz <= 0] = 1e8
    # Smaller cz -> larger value here -> larger marker below.
    cz = np.log(60 / cz ** 0.8) * 2.4

    cx, cy, cz = cx[keep], cy[keep], cz[keep]
    if len(cz) == 0:
        return

    ax.scatter(cx, cy, s=4 ** (cz + 1), alpha=0.7, color='C%d' % (query_id % 5))


def main():
    """Visualize the sampling points of confident queries on the val-mini set.

    For every validation sample the model is run once with ``DUMP.enabled``
    set, then the class scores, sampling points and validity mask of the
    selected stage are read back from ``DUMP.out_dir`` (presumably written by
    the model during the forward pass -- TODO confirm). Queries whose best
    class score exceeds ``--score_threshold`` are drawn on a
    ``num_frames x num_views`` grid of camera images, saved as
    ``outputs/sp_XXXX.jpg``.

    Raises:
        AssertionError: If CUDA is unavailable or the number of visible GPUs
            is not exactly one.
    """
    # Local import: only needed to create the output directory.
    import os

    args = _parse_args()

    # Parse the config file and apply command-line overrides.
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # Point the validation annotation file at its 'val_mini' counterpart.
    cfgs.data.val.ann_file = cfgs.data.val.ann_file.replace('val', 'val_mini')

    # Import the project packages (presumably to register custom modules
    # with the mmdet3d registries -- TODO confirm).
    importlib.import_module('models')
    importlib.import_module('loaders')

    # Replace mmcv's cached loggers with WARNING-level ones (presumably to
    # silence its info output -- TODO confirm).
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    # The script is written for exactly one visible GPU.
    assert torch.cuda.is_available(), 'CUDA is required'
    assert torch.cuda.device_count() == 1, 'exactly one visible GPU is required'

    utils.init_logging(None, cfgs.debug)

    logging.info('Using GPU: %s', torch.cuda.get_device_name(0))
    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading validation set from %s', cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=2,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s', cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()
    model = MMDataParallel(model, [0])

    logging.info('Loading checkpoint from %s', args.weights)
    checkpoint = load_checkpoint(
        model, args.weights, map_location='cuda', strict=True,
        logger=logging.Logger(__name__, logging.ERROR)
    )

    if 'version' in checkpoint:
        VERSION.name = checkpoint['version']

    # savefig does not create missing directories.
    os.makedirs('outputs', exist_ok=True)

    # Column of each view inside a subplot row.
    view_mapping = [1, 2, 0, 4, 5, 3]
    # The filename index arithmetic below assumes six filenames per frame.
    views_per_frame = len(view_mapping)

    model.eval()

    for idx, data in enumerate(val_loader):
        # Re-armed for every sample, as in the original loop.
        DUMP.enabled = True

        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)

        cls_scores = torch.load(
            '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, args.stage_id)
        )[0]
        cls_scores, _ = torch.max(cls_scores, dim=-1)

        # Keep only the queries whose best class score passes the threshold;
        # plain ints make cheaper and simpler indices than 0-d tensors.
        query_ids = torch.where(cls_scores > args.score_threshold)[0].tolist()

        # These dumps do not depend on the frame, so load them once per sample.
        sample_points_cam = torch.load(
            '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, args.stage_id)
        )
        valid_mask = torch.load(
            '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, args.stage_id)
        )

        filenames = data['img_metas'][0].data[0][0]['filename']

        plt.figure(figsize=(240, 49))

        for frame_id in range(args.num_frames):
            for view_id in range(args.num_views):
                filename = filenames[frame_id * views_per_frame + view_id]

                # Crop to the 1600x640 region the point coordinates are
                # scaled to; the context manager closes the file handle.
                with Image.open(filename) as raw_img:
                    img = raw_img.crop((0, 260, 1600, 900))

                # Subplot indices are 1-based.
                plot_id = frame_id * args.num_views + view_mapping[view_id] + 1
                ax = plt.subplot(args.num_frames, args.num_views, plot_id)
                ax.imshow(img)
                ax.axis('off')
                ax.set_xlim(0, 1600)
                ax.set_ylim(640, 0)

                for query_id in query_ids:
                    xyz = sample_points_cam[0, frame_id, view_id, query_id].numpy()
                    mask = valid_mask[0, frame_id, view_id, query_id].numpy()
                    _scatter_sample_points(ax, xyz, mask, query_id)

        plt.tight_layout()
        plt.subplots_adjust(hspace=0.01, wspace=0.01)
        plt.savefig('outputs/sp_%04d.jpg' % idx, dpi=20)
        plt.close()

        logging.info('Visualized result is dumped to outputs/sp_%04d.jpg', idx)
|
|
|
# Run the visualization only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|