| |
| import copy |
| import logging |
| import numpy as np |
| from typing import Callable, List, Union |
| import torch |
| from panopticapi.utils import rgb2id |
|
|
| from detectron2.config import configurable |
| from detectron2.data import MetadataCatalog |
| from detectron2.data import detection_utils as utils |
| from detectron2.data import transforms as T |
|
|
| from .target_generator import PanopticDeepLabTargetGenerator |
|
|
# Public API of this module: only the dataset mapper is exported.
__all__ = [
    "PanopticDeeplabDatasetMapper",
]
|
|
|
|
class PanopticDeeplabDatasetMapper:
    """
    Map one Detectron2 dataset dict to Panoptic-DeepLab training inputs.

    The callable currently does the following:

    1. Read the image from "file_name" and the RGB-encoded panoptic label
       from "pan_seg_file_name".
    2. Apply the configured augmentations (by default, from ``from_config``:
       shortest-edge resize, optional random crop, random flip) jointly to
       image and label.
    3. Convert the image to a Tensor and generate training targets from the
       label.
    """

    @configurable
    def __init__(
        self,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        panoptic_target_generator: Callable,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            augmentations: a list of augmentations or deterministic transforms to apply
                to both the image and the panoptic label.
            image_format: an image format supported by :func:`detection_utils.read_image`.
            panoptic_target_generator: a callable that takes the decoded panoptic id map
                and the dataset dict's "segments_info" and returns training targets for
                the model; its result is merged into the mapped dict via ``dict.update``.
        """
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format

        logger = logging.getLogger(__name__)
        # Lazy %-style argument: the list is only stringified if the record is emitted.
        logger.info("Augmentations used in training: %s", augmentations)

        self.panoptic_target_generator = panoptic_target_generator

    @classmethod
    def from_config(cls, cfg):
        """
        Build the keyword arguments for ``__init__`` from a Detectron2 config.

        Augmentations, in order: ``ResizeShortestEdge`` (INPUT.MIN_SIZE_TRAIN,
        INPUT.MAX_SIZE_TRAIN, INPUT.MIN_SIZE_TRAIN_SAMPLING), ``RandomCrop`` when
        INPUT.CROP.ENABLED, then ``RandomFlip``.

        The target generator is configured from the metadata of the FIRST dataset
        in cfg.DATASETS.TRAIN, so this mapper is meant for training data.
        """
        augs = [
            T.ResizeShortestEdge(
                cfg.INPUT.MIN_SIZE_TRAIN,
                cfg.INPUT.MAX_SIZE_TRAIN,
                cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
            )
        ]
        if cfg.INPUT.CROP.ENABLED:
            augs.append(T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
        augs.append(T.RandomFlip())

        # Assume the mapper always applies to the training set; only the first
        # training dataset's metadata is consulted.
        dataset_names = cfg.DATASETS.TRAIN
        meta = MetadataCatalog.get(dataset_names[0])
        panoptic_target_generator = PanopticDeepLabTargetGenerator(
            ignore_label=meta.ignore_label,
            # The contiguous (model-side) ids of the "thing" classes.
            thing_ids=list(meta.thing_dataset_id_to_contiguous_id.values()),
            sigma=cfg.INPUT.GAUSSIAN_SIGMA,
            ignore_stuff_in_offset=cfg.INPUT.IGNORE_STUFF_IN_OFFSET,
            small_instance_area=cfg.INPUT.SMALL_INSTANCE_AREA,
            small_instance_weight=cfg.INPUT.SMALL_INSTANCE_WEIGHT,
            ignore_crowd_in_semantic=cfg.INPUT.IGNORE_CROWD_IN_SEMANTIC,
        )

        return {
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "panoptic_target_generator": panoptic_target_generator,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Must contain "file_name", "pan_seg_file_name" and "segments_info".

        Returns:
            dict: a format that builtin models in detectron2 accept. It is a copy of
            the input with "pan_seg_file_name" removed, "image" added, and the
            target generator's outputs merged in.
        """
        # Deep-copy because the dict is mutated below (pop/assign/update) and the
        # caller's copy must stay intact for later epochs.
        dataset_dict = copy.deepcopy(dataset_dict)

        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # The panoptic label is stored as an RGB image; it is decoded to ids below.
        pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")

        # Reuse the semantic-segmentation transform path for the panoptic label so
        # that image and label receive the same geometric transforms. The
        # augmentation list updates aug_input in place.
        aug_input = T.AugInput(image, sem_seg=pan_seg_gt)
        self.augmentations(aug_input)
        image, pan_seg_gt = aug_input.image, aug_input.sem_seg

        # transpose(2, 0, 1) moves the last axis to the front (HWC -> CHW,
        # assuming read_image returns an HxWxC array); ascontiguousarray makes
        # the transposed view contiguous before wrapping it as a tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        # Decode the RGB-encoded label to per-pixel segment ids and build targets.
        targets = self.panoptic_target_generator(
            rgb2id(pan_seg_gt), dataset_dict["segments_info"]
        )
        dataset_dict.update(targets)

        return dataset_dict
|
|