Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # from https://github.com/facebookresearch/detr/blob/main/d2/detr/dataset_mapper.py | |
| import copy | |
| import logging | |
| from os import path | |
| import numpy as np | |
| import torch | |
| from detectron2.data import detection_utils as utils | |
| from detectron2.data import transforms as T | |
| import json | |
| import pickle | |
| from detectron2.structures import ( | |
| BitMasks, | |
| Boxes, | |
| BoxMode, | |
| Instances, | |
| Keypoints, | |
| PolygonMasks, | |
| RotatedBoxes, | |
| polygons_to_bitmask, | |
| ) | |
| __all__ = ["DetrDatasetMapper"] | |
| def build_transform_gen(cfg, is_train): | |
| """ | |
| Create a list of :class:`TransformGen` from config. | |
| Returns: | |
| list[TransformGen] | |
| """ | |
| if is_train: | |
| min_size = cfg.INPUT.MIN_SIZE_TRAIN | |
| max_size = cfg.INPUT.MAX_SIZE_TRAIN | |
| sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING | |
| else: | |
| min_size = cfg.INPUT.MIN_SIZE_TEST | |
| max_size = cfg.INPUT.MAX_SIZE_TEST | |
| sample_style = "choice" | |
| if sample_style == "range": | |
| assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) | |
| logger = logging.getLogger(__name__) | |
| tfm_gens = [] | |
| # if is_train: | |
| # tfm_gens.append(T.RandomFlip()) | |
| tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) | |
| if is_train: | |
| logger.info("TransformGens used in training: " + str(tfm_gens)) | |
| return tfm_gens | |
| def build_transform_gen_w(cfg, is_train): | |
| """ | |
| Create a list of :class:`TransformGen` from config. | |
| Returns: | |
| list[TransformGen] | |
| """ | |
| if is_train: | |
| min_size = cfg.INPUT.MIN_SIZE_TRAIN | |
| max_size = 800 | |
| sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING | |
| else: | |
| min_size = cfg.INPUT.MIN_SIZE_TEST | |
| max_size = cfg.INPUT.MAX_SIZE_TEST | |
| sample_style = "choice" | |
| if sample_style == "range": | |
| assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) | |
| logger = logging.getLogger(__name__) | |
| tfm_gens = [] | |
| # if is_train: | |
| # tfm_gens.append(T.RandomFlip()) | |
| tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) | |
| if is_train: | |
| logger.info("TransformGens used in training: " + str(tfm_gens)) | |
| return tfm_gens | |
| class DetrDatasetMapper: | |
| """ | |
| A callable which takes a dataset dict in Detectron2 Dataset format, | |
| and map it into a format used by DETR. | |
| The callable currently does the following: | |
| 1. Read the image from "file_name" | |
| 2. Applies geometric transforms to the image and annotation | |
| 3. Find and applies suitable cropping to the image and annotation | |
| 4. Prepare image and annotation to Tensors | |
| """ | |
| def __init__(self, cfg, is_train=True): | |
| if cfg.INPUT.CROP.ENABLED and is_train: | |
| self.crop_gen = [ | |
| T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), | |
| T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), | |
| ] | |
| else: | |
| self.crop_gen = None | |
| self.mask_on = cfg.MODEL.MASK_ON | |
| self.tfm_gens = build_transform_gen(cfg, is_train) | |
| self.tfm_gens_w = build_transform_gen_w(cfg, is_train) | |
| logging.getLogger(__name__).info( | |
| "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen)) | |
| ) | |
| self.img_format = cfg.INPUT.FORMAT | |
| self.is_train = is_train | |
| self.cfg = cfg | |
| logger = logging.getLogger("detectron2") | |
| def __call__(self, dataset_dict): | |
| """ | |
| Args: | |
| dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. | |
| Returns: | |
| dict: a format that builtin models in detectron2 accept | |
| """ | |
| dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below | |
| image = utils.read_image(dataset_dict["file_name"], format=self.img_format) | |
| utils.check_image_size(dataset_dict, image) | |
| word_grid_path = dataset_dict["file_name"].replace("images", "word_grids").replace(".jpg", ".pkl") | |
| if path.exists(word_grid_path): | |
| with open(word_grid_path, "rb") as f: | |
| sample_inputs = pickle.load(f) | |
| input_ids = sample_inputs["input_ids"] | |
| bbox_subword_list = sample_inputs["bbox_subword_list"] | |
| else: | |
| input_ids = [] | |
| bbox_subword_list = [] | |
| print(f"No word grid pkl in: {word_grid_path}") | |
| image_shape_ori = image.shape[:2] # h, w | |
| if self.crop_gen is None: | |
| if image_shape_ori[0] > image_shape_ori[1]: | |
| image, transforms = T.apply_transform_gens(self.tfm_gens, image) | |
| else: | |
| image, transforms = T.apply_transform_gens(self.tfm_gens_w, image) | |
| else: | |
| if np.random.rand() > 0.5: | |
| if image_shape_ori[0] > image_shape_ori[1]: | |
| image, transforms = T.apply_transform_gens(self.tfm_gens, image) | |
| else: | |
| image, transforms = T.apply_transform_gens(self.tfm_gens_w, image) | |
| else: | |
| image, transforms = T.apply_transform_gens( | |
| self.tfm_gens_w[:-1] + self.crop_gen + self.tfm_gens_w[-1:], image | |
| ) | |
| image_shape = image.shape[:2] # h, w | |
| # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, | |
| # but not efficient on large generic data structures due to the use of pickle & mp.Queue. | |
| # Therefore it's important to use torch.Tensor. | |
| dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) | |
| ## 产出 text grid | |
| bbox = [] | |
| for bbox_per_subword in bbox_subword_list: | |
| text_word = {} | |
| text_word["bbox"] = bbox_per_subword.tolist() | |
| text_word["bbox_mode"] = BoxMode.XYWH_ABS | |
| utils.transform_instance_annotations(text_word, transforms, image_shape) | |
| bbox.append(text_word["bbox"]) | |
| dataset_dict["input_ids"] = input_ids | |
| dataset_dict["bbox"] = bbox | |
| if not self.is_train: | |
| # USER: Modify this if you want to keep them for some reason. | |
| dataset_dict.pop("annotations", None) | |
| return dataset_dict | |
| if "annotations" in dataset_dict: | |
| # USER: Modify this if you want to keep them for some reason. | |
| for anno in dataset_dict["annotations"]: | |
| if not self.mask_on: | |
| anno.pop("segmentation", None) | |
| anno.pop("keypoints", None) | |
| # USER: Implement additional transformations if you have other types of data | |
| annos = [ | |
| utils.transform_instance_annotations(obj, transforms, image_shape) | |
| for obj in dataset_dict.pop("annotations") | |
| if obj.get("iscrowd", 0) == 0 | |
| ] | |
| instances = utils.annotations_to_instances(annos, image_shape) | |
| dataset_dict["instances"] = utils.filter_empty_instances(instances) | |
| return dataset_dict | |