"""Dataset mapper for the MeViS referring video segmentation dataset.

Maps one video dict in YouTube-VIS format into the per-frame tensor format
consumed by the model, including tokenization of the referring sentence.
"""
import copy
import logging
import random
from typing import List, Union

import numpy as np
import torch
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import (
    BitMasks,
    Boxes,
    BoxMode,
    Instances,
)
from transformers import BertTokenizer, RobertaTokenizerFast
import spacy

from .augmentation import build_augmentation

__all__ = ["MeViSDatasetMapper"]


def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5):
    """
    Flag empty instances in an `Instances` object by setting their ``gt_ids`` to -1.

    Unlike detectron2's ``utils.filter_empty_instances``, this does NOT drop
    empty instances: all slots are kept (so instance ordering stays aligned
    across the sampled frames of a video) and emptiness is recorded via
    ``gt_ids == -1``.

    Args:
        instances (Instances):
        by_box (bool): whether to flag instances with empty boxes
        by_mask (bool): whether to flag instances with empty masks
        box_threshold (float): minimum width and height to be considered non-empty

    Returns:
        Instances: the same object, with ``gt_ids`` of empty instances set to -1.
    """
    assert by_box or by_mask
    checks = []
    if by_box:
        checks.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        checks.append(instances.gt_masks.nonempty())
    # Dummy annotations carry category_id == -1 and must always be flagged,
    # so this check is unconditional and `checks` is never empty.
    checks.append(instances.gt_classes != -1)

    keep = checks[0]
    for c in checks[1:]:
        keep = keep & c
    instances.gt_ids[~keep] = -1
    return instances


def _get_dummy_anno():
    """Placeholder annotation for an object that is absent from a given frame."""
    return {
        "iscrowd": 0,
        "category_id": -1,  # sentinel: flagged later by filter_empty_instances
        "id": -1,
        "bbox": np.array([0, 0, 0, 0]),
        "bbox_mode": BoxMode.XYXY_ABS,
        "segmentation": [np.array([0.0] * 6)],
    }


class MeViSDatasetMapper:
    """
    A callable which takes a dataset dict in YouTube-VIS Dataset format,
    and map it into a format used by the model.
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        is_tgt: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        sampling_frame_num: int = 2,
        sampling_frame_range: int = 5,
        sampling_frame_shuffle: bool = False,
        num_classes: int = 40,
        src_dataset_name: str = "",
        tgt_dataset_name: str = "",
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            is_tgt: whether this dataset is the target dataset; when False,
                category ids are remapped from the source to the target taxonomy
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
        """
        # fmt: off
        self.is_train               = is_train
        self.is_tgt                 = is_tgt
        self.augmentations          = T.AugmentationList(augmentations)
        self.image_format           = image_format
        self.use_instance_mask      = use_instance_mask
        self.sampling_frame_num     = sampling_frame_num
        self.sampling_frame_range   = sampling_frame_range
        self.sampling_frame_shuffle = sampling_frame_shuffle
        self.num_classes            = num_classes

        if not is_tgt:
            self.src_metadata = MetadataCatalog.get(src_dataset_name)
            self.tgt_metadata = MetadataCatalog.get(tgt_dataset_name)
            # NOTE(review): the *_TO_* mapping tables below are not imported in
            # this file, so this branch raises NameError if reached — confirm
            # where OVIS_TO_YTVIS_* / YTVIS_*_TO_OVIS should be imported from
            # before constructing this mapper with is_tgt=False.
            if tgt_dataset_name.startswith("ytvis_2019"):
                src2tgt = OVIS_TO_YTVIS_2019
            elif tgt_dataset_name.startswith("ytvis_2021"):
                src2tgt = OVIS_TO_YTVIS_2021
            elif tgt_dataset_name.startswith("ovis"):
                if src_dataset_name.startswith("ytvis_2019"):
                    src2tgt = YTVIS_2019_TO_OVIS
                elif src_dataset_name.startswith("ytvis_2021"):
                    src2tgt = YTVIS_2021_TO_OVIS
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError

            # Translate source contiguous category ids to target contiguous ids.
            self.src2tgt = {}
            for k, v in src2tgt.items():
                self.src2tgt[
                    self.src_metadata.thing_dataset_id_to_contiguous_id[k]
                ] = self.tgt_metadata.thing_dataset_id_to_contiguous_id[v]
        # fmt: on

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

        self.max_tokens = 40
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        self.nlp = spacy.load('en_core_web_sm')

    @classmethod
    def from_config(cls, cfg, is_train: bool = True, is_tgt: bool = True):
        """Build the constructor kwargs from a detectron2 config node."""
        augs = build_augmentation(cfg, is_train)

        sampling_frame_num = cfg.INPUT.SAMPLING_FRAME_NUM
        sampling_frame_range = cfg.INPUT.SAMPLING_FRAME_RANGE
        sampling_frame_shuffle = cfg.INPUT.SAMPLING_FRAME_SHUFFLE

        ret = {
            "is_train": is_train,
            "is_tgt": is_tgt,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "sampling_frame_num": sampling_frame_num,
            "sampling_frame_range": sampling_frame_range,
            "sampling_frame_shuffle": sampling_frame_shuffle,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "tgt_dataset_name": cfg.DATASETS.TRAIN[-1],
        }

        return ret

    @staticmethod
    def _merge_masks(x):
        # Union of all instance masks; clamp keeps the result binary.
        return x.sum(dim=0, keepdim=True).clamp(max=1)

    def __call__(self, dataset):
        """
        Args:
            dataset (dict): Metadata of one video, in YTVIS Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        # During training, re-sample frames until at least one sampled frame
        # contains a visible (non-empty-mask) instance.
        instance_check = False
        while not instance_check:
            # TODO consider examining below deepcopy as it costs huge amount of computations.
            dataset_dict = copy.deepcopy(dataset)  # it will be modified by code below

            video_annos = dataset_dict.pop("annotations", None)
            file_names = dataset_dict.pop("file_names", None)

            video_length = dataset_dict["length"]
            if self.is_train:
                # Pick a random reference frame, then sample the remaining
                # (sampling_frame_num - 1) frames from a window around it.
                ref_frame = random.randrange(video_length)

                start_idx = max(0, ref_frame - self.sampling_frame_range)
                end_idx = min(video_length, ref_frame + self.sampling_frame_range + 1)
                available_frames = (
                    list(range(start_idx, ref_frame)) + list(range(ref_frame + 1, end_idx))
                )
                # Sample with replacement only when the window has too few
                # distinct candidate frames.
                replace = len(set(available_frames)) <= self.sampling_frame_num - 1

                selected_idx = np.random.choice(
                    np.array(available_frames),
                    self.sampling_frame_num - 1,
                    replace=replace,
                )
                selected_idx = sorted(selected_idx.tolist() + [ref_frame])
                if self.sampling_frame_shuffle:
                    random.shuffle(selected_idx)
            else:
                selected_idx = range(video_length)

            if self.is_train:
                # Map each object id to a stable slot index so that instance k
                # refers to the same object in every sampled frame.
                _ids = set()
                for frame_idx in selected_idx:
                    _ids.update([anno["id"] for anno in video_annos[frame_idx]])
                ids = dict()
                for i, _id in enumerate(_ids):
                    ids[_id] = i

            dataset_dict["video_len"] = len(video_annos)
            dataset_dict["frame_idx"] = list(selected_idx)
            dataset_dict["image"] = []
            dataset_dict["instances"] = []
            dataset_dict["gt_masks_merge"] = []
            dataset_dict["file_names"] = []
            valid = []
            for frame_idx in selected_idx:
                dataset_dict["file_names"].append(file_names[frame_idx])

                # Read image
                image = utils.read_image(file_names[frame_idx], format=self.image_format)
                # Pre-augmentation frame size; overwritten each iteration —
                # presumably all frames of a video share one resolution.
                height, width = image.shape[:2]
                dataset_dict["height"] = height
                dataset_dict["width"] = width

                aug_input = T.AugInput(image)
                transforms = self.augmentations(aug_input)
                image = aug_input.image

                image_shape = image.shape[:2]  # h, w
                # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
                # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
                # Therefore it's important to use torch.Tensor.
                dataset_dict["image"].append(
                    torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
                )

                if (video_annos is None) or (not self.is_train):
                    continue

                # NOTE copy() is to prevent annotations getting changed from applying augmentations
                _frame_annos = []
                for anno in video_annos[frame_idx]:
                    _anno = {}
                    for k, v in anno.items():
                        _anno[k] = copy.deepcopy(v)
                    _frame_annos.append(_anno)

                # USER: Implement additional transformations if you have other types of data
                annos = [
                    utils.transform_instance_annotations(obj, transforms, image_shape)
                    for obj in _frame_annos
                    if obj.get("iscrowd", 0) == 0
                ]
                # Objects missing from this frame keep their slot via a dummy anno.
                sorted_annos = [_get_dummy_anno() for _ in range(len(ids))]
                for _anno in annos:
                    idx = ids[_anno["id"]]
                    sorted_annos[idx] = _anno
                _gt_ids = [_anno["id"] for _anno in sorted_annos]

                instances = utils.annotations_to_instances(
                    sorted_annos, image_shape, mask_format="bitmask"
                )
                if not self.is_tgt:
                    instances.gt_classes = torch.tensor(
                        [self.src2tgt[c] if c in self.src2tgt else -1
                         for c in instances.gt_classes.tolist()]
                    )
                instances.gt_ids = torch.tensor(_gt_ids)
                if not instances.has("gt_masks"):
                    instances.gt_masks = BitMasks(torch.empty((0, *image_shape)))
                # Derive boxes from masks so box/mask emptiness stays consistent.
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
                instances = filter_empty_instances(instances)

                merged_masks = self._merge_masks(instances.gt_masks.tensor)
                dataset_dict['gt_masks_merge'].append(merged_masks)
                valid.append(1 if (merged_masks > 0).any() else 0)
                dataset_dict["instances"].append(instances)

            # Stop re-sampling once at least one instance is visible (training),
            # or immediately at inference.
            instance_check = bool(torch.any(torch.tensor(valid) == 1)) or not self.is_train

        # Tokenize the referring sentence with fixed-length padding.
        sentence_raw = dataset_dict['sentence']
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens

        input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
        # Truncate to the fixed token budget, then pad with zeros.
        input_ids = input_ids[:self.max_tokens]
        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)

        dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0)
        dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0)
        dataset_dict['dataset_name'] = 'mevis'
        return dataset_dict