|
|
|
|
| import argparse
|
| import os
|
| from itertools import chain
|
| import cv2
|
| import tqdm
|
|
|
| from detectron2.config import get_cfg
|
| from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader
|
| from detectron2.data import detection_utils as utils
|
| from detectron2.data.build import filter_images_with_few_keypoints
|
| from detectron2.utils.logger import setup_logger
|
| from detectron2.utils.visualizer import Visualizer
|
|
|
|
|
def setup(args):
    """Build the frozen config used by this script.

    Starts from ``get_cfg()``, merges the file named by ``args.config_file``
    when one was given, then applies the command-line overrides in
    ``args.opts`` on top, and finally freezes the result.

    Args:
        args: Parsed command-line namespace providing ``config_file`` and
            ``opts``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()

    config_path = args.config_file
    if config_path:
        config.merge_from_file(config_path)

    # Command-line overrides are applied last so they win over the file.
    config.merge_from_list(args.opts)
    config.freeze()
    return config
|
|
|
|
|
def parse_args(in_args=None):
    """Parse the command-line options of the ground-truth visualizer.

    Args:
        in_args: Optional list of argument strings. It is forwarded to
            ``ArgumentParser.parse_args``, which reads the process arguments
            when given ``None``.

    Returns:
        The populated ``argparse.Namespace`` with the attributes ``source``,
        ``config_file``, ``output_dir``, ``show`` and ``opts``.
    """
    arg_parser = argparse.ArgumentParser(description="Visualize ground-truth data")

    arg_parser.add_argument(
        "--source",
        required=True,
        choices=["annotation", "dataloader"],
        help="visualize the annotations or the data loader (with pre-processing)",
    )

    # Plain optional flags, registered in the order they appear in --help.
    simple_options = (
        ("--config-file", {"metavar": "FILE", "help": "path to config file"}),
        ("--output-dir", {"default": "./", "help": "path to output directory"}),
        ("--show", {"action": "store_true", "help": "show output in a window"}),
    )
    for flag, option_kwargs in simple_options:
        arg_parser.add_argument(flag, **option_kwargs)

    # Everything left over after the known options is collected into ``opts``.
    arg_parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    namespace = arg_parser.parse_args(in_args)
    return namespace
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: draw the ground truth of the training dataset(s),
    # either straight from the dataset annotations or from the batches the
    # training data loader yields, and show each result in a window or save
    # it into the output directory.
    args = parse_args()
    logger = setup_logger()
    logger.info("Arguments: %s", args)
    cfg = setup(args)

    # Fail with an actionable message instead of the opaque IndexError that
    # ``cfg.DATASETS.TRAIN[0]`` below would raise when nothing is configured.
    if not cfg.DATASETS.TRAIN:
        raise SystemExit(
            "No training dataset configured: cfg.DATASETS.TRAIN is empty. "
            "Provide one via --config-file or the trailing config options."
        )

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    # Metadata of the first training dataset is used for every drawing,
    # including the class-name lookup in the dataloader branch.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        """Show ``vis`` in a window (``--show``) or save it as ``fname``.

        Args:
            vis: Drawing result exposing ``get_image()`` and ``save(path)``.
            fname: File name, printed in show mode and joined with the
                output directory in save mode.
        """
        if args.show:
            print(fname)
            # Reverse the last axis before handing the image to OpenCV
            # (presumably RGB -> BGR, the channel order cv2.imshow expects).
            cv2.imshow("window", vis.get_image()[:, :, ::-1])
            # Block until a key is pressed before moving to the next image.
            cv2.waitKey()
        else:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)

    # Images are drawn at twice the size when shown interactively.
    scale = 2.0 if args.show else 1.0
    if args.source == "dataloader":
        # NOTE(review): the training loader may yield batches indefinitely;
        # if so, this branch only stops when interrupted -- confirm.
        train_data_loader = build_detection_train_loader(cfg)
        for batch in train_data_loader:
            for per_image in batch:
                # Move the first axis last (presumably CHW -> HWC) and
                # convert the tensor to a numpy array for drawing.
                img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
                img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                target_fields = per_image["instances"].get_fields()
                # Translate class indices into human-readable class names.
                labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=target_fields.get("gt_boxes", None),
                    masks=target_fields.get("gt_masks", None),
                    keypoints=target_fields.get("gt_keypoints", None),
                )
                output(vis, str(per_image["image_id"]) + ".jpg")
    else:
        # Concatenate the dataset dicts of every configured training dataset.
        dicts = list(chain.from_iterable([DatasetCatalog.get(k) for k in cfg.DATASETS.TRAIN]))
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)
        for dic in tqdm.tqdm(dicts):
            img = utils.read_image(dic["file_name"], "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(dic)
            output(vis, os.path.basename(dic["file_name"]))
|
|
|