#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""Visualize ground-truth data described by a detectron2 lazy config.

Two sources are supported (``--source``):

* ``annotation`` -- draw the raw dataset dicts registered in ``DatasetCatalog``.
* ``dataloader`` -- draw the samples produced by the instantiated training
  data loader, i.e. after the mapper's pre-processing has been applied.

Each visualization is either shown in an OpenCV window (``--show``) or saved
under ``--output-dir``.
"""
import argparse
import os
from itertools import chain

import cv2
import tqdm

from detectron2.config import LazyConfig, instantiate
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer


def setup(args):
    """Load the lazy config and apply the command-line overrides.

    Args:
        args: parsed arguments; ``args.config_file`` is the config path and
            ``args.opts`` the list of override strings.

    Returns:
        The loaded config with ``dataloader.train.num_workers`` set to 0.
    """
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    # NOTE(review): presumably 0 workers keeps data loading in the main
    # process, which is simpler for an interactive tool -- confirm.
    cfg.dataloader.train.num_workers = 0
    return cfg


def parse_args(in_args=None):
    """Parse the command-line arguments of the visualization script.

    Args:
        in_args: optional list of argument strings; when ``None`` argparse
            reads ``sys.argv``.

    Returns:
        argparse.Namespace with ``source``, ``config_file``, ``output_dir``,
        ``show`` and ``opts``.
    """
    parser = argparse.ArgumentParser(description="Visualize ground-truth data")
    parser.add_argument(
        "--source",
        choices=["annotation", "dataloader"],
        required=True,
        help="visualize the annotations or the data loader (with pre-processing)",
    )
    # The config path is passed unconditionally to LazyConfig.load() in
    # setup(), so require it here and let argparse report a clear error
    # instead of failing later on a None path.
    parser.add_argument(
        "--config-file", metavar="FILE", required=True, help="path to config file"
    )
    parser.add_argument("--output-dir", default="./", help="path to output directory")
    parser.add_argument("--show", action="store_true", help="show output in a window")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(in_args)


def _dataset_names(cfg):
    """Return the training dataset names of *cfg* as a non-empty list.

    ``cfg.dataloader.train.dataset.names`` may be a single string or a
    sequence of strings; both are normalized to a list.

    Raises:
        ValueError: if the config lists no dataset names.
    """
    names = cfg.dataloader.train.dataset.names
    names = [names] if isinstance(names, str) else list(names)
    if not names:
        raise ValueError("cfg.dataloader.train.dataset.names is empty; nothing to visualize")
    return names


def _output(vis, fname, show, dirname):
    """Show *vis* in a window, or save it as ``dirname/fname``.

    Args:
        vis: the drawn visualization (must provide ``get_image``/``save``).
        fname: file name used for the saved image (printed when showing).
        show: display in an OpenCV window instead of writing to disk.
        dirname: output directory used when saving.
    """
    if show:
        print(fname)
        # Images are handed to the Visualizer as RGB; reverse the channel
        # axis because OpenCV displays BGR.
        cv2.imshow("window", vis.get_image()[:, :, ::-1])
        cv2.waitKey()  # block until a key is pressed
        return
    filepath = os.path.join(dirname, fname)
    print("Saving to {} ...".format(filepath))
    vis.save(filepath)


def _visualize_dataloader(cfg, metadata, scale, emit):
    """Draw the ground truth of every sample yielded by the train loader.

    NOTE(review): detectron2 training loaders commonly iterate forever; if
    that applies to this config the loop only ends when interrupted -- confirm.
    """
    train_data_loader = instantiate(cfg.dataloader.train)
    img_format = cfg.dataloader.train.mapper.img_format
    for batch in train_data_loader:
        for per_image in batch:
            # Pytorch tensor is in (C, H, W) format
            img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
            img = utils.convert_image_to_rgb(img, img_format)

            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            target_fields = per_image["instances"].get_fields()
            labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
            vis = visualizer.overlay_instances(
                labels=labels,
                boxes=target_fields.get("gt_boxes", None),
                masks=target_fields.get("gt_masks", None),
                keypoints=target_fields.get("gt_keypoints", None),
            )
            emit(vis, str(per_image["image_id"]) + ".jpg")


def _visualize_annotations(names, metadata, scale, emit):
    """Draw the raw dataset dicts of every dataset in *names*."""
    dicts = list(chain.from_iterable(DatasetCatalog.get(k) for k in names))
    for dic in tqdm.tqdm(dicts):
        img = utils.read_image(dic["file_name"], "RGB")
        visualizer = Visualizer(img, metadata=metadata, scale=scale)
        vis = visualizer.draw_dataset_dict(dic)
        emit(vis, os.path.basename(dic["file_name"]))


def main():
    """Script entry point: parse arguments, load the config and visualize."""
    args = parse_args()
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)

    names = _dataset_names(cfg)
    # Metadata of the first dataset is used for all of them.
    metadata = MetadataCatalog.get(names[0])

    def emit(vis, fname):
        _output(vis, fname, args.show, dirname)

    scale = 1.0
    if args.source == "dataloader":
        _visualize_dataloader(cfg, metadata, scale, emit)
    else:
        _visualize_annotations(names, metadata, scale, emit)


if __name__ == "__main__":
    main()