|
|
import copy |
|
|
import logging |
|
|
import random |
|
|
import numpy as np |
|
|
from typing import List, Union |
|
|
import torch |
|
|
|
|
|
from detectron2.config import configurable |
|
|
from detectron2.structures import ( |
|
|
BitMasks, |
|
|
Boxes, |
|
|
BoxMode, |
|
|
Instances, |
|
|
) |
|
|
|
|
|
from detectron2.data import detection_utils as utils |
|
|
from detectron2.data import transforms as T |
|
|
from detectron2.data import MetadataCatalog |
|
|
|
|
|
from .augmentation import build_augmentation |
|
|
from transformers import BertTokenizer, RobertaTokenizerFast |
|
|
import spacy |
|
|
|
|
|
__all__ = ["MeViSDatasetMapper"] |
|
|
|
|
|
|
|
|
def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5):
    """
    Mark empty instances in an `Instances` object by setting their ``gt_ids`` to -1.

    Note: despite the name, nothing is removed — the object keeps the same
    number of instances, and downstream code is expected to ignore entries
    whose ``gt_ids`` is -1.

    Args:
        instances (Instances): must carry ``gt_boxes``, ``gt_classes``,
            ``gt_ids`` and optionally ``gt_masks``.
        by_box (bool): treat an instance with an empty box as empty
        by_mask (bool): treat an instance with an empty mask as empty
        box_threshold (float): minimum width and height for a box to count
            as non-empty

    Returns:
        Instances: the same object, with ``gt_ids`` of empty instances set to -1.
    """
    assert by_box or by_mask

    # Collect one boolean keep-mask per criterion.
    keep_flags = []
    if by_box:
        keep_flags.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        keep_flags.append(instances.gt_masks.nonempty())
    # Instances already labelled -1 (e.g. dummy padding) are always dropped.
    keep_flags.append(instances.gt_classes != -1)

    if not keep_flags:
        return instances

    # AND all criteria together: an instance survives only if every test passes.
    keep = keep_flags[0]
    for flag in keep_flags[1:]:
        keep = keep & flag

    instances.gt_ids[~keep] = -1
    return instances
|
|
|
|
|
|
|
|
def _get_dummy_anno():
    """
    Build a placeholder annotation for an object absent from a frame.

    The dummy carries ``id``/``category_id`` of -1 (so it is discarded by
    `filter_empty_instances` downstream), a degenerate XYXY box and a
    degenerate 3-point polygon segmentation.
    """
    dummy = {}
    dummy["iscrowd"] = 0
    dummy["category_id"] = -1
    dummy["id"] = -1
    # Zero-area box; np.zeros(4, dtype=int) matches np.array([0, 0, 0, 0]).
    dummy["bbox"] = np.zeros(4, dtype=int)
    dummy["bbox_mode"] = BoxMode.XYXY_ABS
    # One polygon of three (x, y) points, all at the origin.
    dummy["segmentation"] = [np.zeros(6)]
    return dummy
|
|
|
|
|
|
|
|
class MeViSDatasetMapper:
    """
    A callable which takes a dataset dict in YouTube-VIS Dataset format,
    and map it into a format used by the model.

    At training time it samples ``sampling_frame_num`` frames around a random
    reference frame, applies augmentations, converts per-frame annotations to
    ``Instances`` (padded with dummy annotations so every sampled frame lists
    the same set of object ids), and tokenizes the referring sentence with a
    RoBERTa tokenizer. At inference time every frame is used and annotation
    processing is skipped.
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        is_tgt: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        sampling_frame_num: int = 2,
        sampling_frame_range: int = 5,
        sampling_frame_shuffle: bool = False,
        num_classes: int = 40,
        src_dataset_name: str = "",
        tgt_dataset_name: str = "",
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: whether it's used in training or inference
            is_tgt: if False, remap source-dataset category ids into the
                target dataset's contiguous ids (cross-dataset training)
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            sampling_frame_num: number of frames sampled per training clip
            sampling_frame_range: max distance (in frames) of sampled frames
                from the random reference frame
            sampling_frame_shuffle: shuffle the sampled frames instead of
                keeping temporal order
            num_classes: number of thing classes (stored; not used in this file)
            src_dataset_name: source dataset (only used when ``is_tgt`` is False)
            tgt_dataset_name: target dataset (only used when ``is_tgt`` is False)
        """
        self.is_train = is_train
        self.is_tgt = is_tgt
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.sampling_frame_num = sampling_frame_num
        self.sampling_frame_range = sampling_frame_range
        self.sampling_frame_shuffle = sampling_frame_shuffle
        self.num_classes = num_classes

        # Cross-dataset training: build a category-id remapping table from the
        # source dataset's contiguous ids to the target dataset's.
        if not is_tgt:
            self.src_metadata = MetadataCatalog.get(src_dataset_name)
            self.tgt_metadata = MetadataCatalog.get(tgt_dataset_name)
            # NOTE(review): OVIS_TO_YTVIS_2019 / OVIS_TO_YTVIS_2021 /
            # YTVIS_2019_TO_OVIS / YTVIS_2021_TO_OVIS are not imported anywhere
            # in this file — taking this branch would raise NameError.
            # Confirm the intended import (likely from a sibling module).
            if tgt_dataset_name.startswith("ytvis_2019"):
                src2tgt = OVIS_TO_YTVIS_2019
            elif tgt_dataset_name.startswith("ytvis_2021"):
                src2tgt = OVIS_TO_YTVIS_2021
            elif tgt_dataset_name.startswith("ovis"):
                if src_dataset_name.startswith("ytvis_2019"):
                    src2tgt = YTVIS_2019_TO_OVIS
                elif src_dataset_name.startswith("ytvis_2021"):
                    src2tgt = YTVIS_2021_TO_OVIS
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError

            # Translate the dataset-id mapping into contiguous-id space.
            self.src2tgt = {}
            for k, v in src2tgt.items():
                self.src2tgt[
                    self.src_metadata.thing_dataset_id_to_contiguous_id[k]
                ] = self.tgt_metadata.thing_dataset_id_to_contiguous_id[v]

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
        # Fixed-length language padding/truncation budget (tokens incl. specials).
        self.max_tokens = 40

        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        # NOTE(review): self.nlp is loaded but never used in this class —
        # presumably used elsewhere or dead; confirm before removing.
        self.nlp = spacy.load('en_core_web_sm')

    @classmethod
    def from_config(cls, cfg, is_train: bool = True, is_tgt: bool = True):
        """Build the __init__ kwargs from a detectron2 config node."""
        augs = build_augmentation(cfg, is_train)

        sampling_frame_num = cfg.INPUT.SAMPLING_FRAME_NUM
        sampling_frame_range = cfg.INPUT.SAMPLING_FRAME_RANGE
        sampling_frame_shuffle = cfg.INPUT.SAMPLING_FRAME_SHUFFLE

        ret = {
            "is_train": is_train,
            "is_tgt": is_tgt,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "sampling_frame_num": sampling_frame_num,
            "sampling_frame_range": sampling_frame_range,
            "sampling_frame_shuffle": sampling_frame_shuffle,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            # NOTE(review): "src_dataset_name" is never supplied here, so when
            # is_tgt is False __init__ receives the "" default and
            # MetadataCatalog.get("") would be queried — confirm whether
            # cfg.DATASETS.TRAIN[0] (or similar) was intended.
            "tgt_dataset_name": cfg.DATASETS.TRAIN[-1],
        }

        return ret

    @staticmethod
    def _merge_masks(x):
        # Collapse per-instance masks (N, H, W) into one binary foreground
        # mask of shape (1, H, W): union of all instances.
        return x.sum(dim=0, keepdim=True).clamp(max=1)

    def __call__(self, dataset):
        # Retry loop: at training time, resample frames until at least one
        # sampled frame contains a non-empty ground-truth mask.
        instance_check = False

        while not instance_check:
            # Work on a deep copy — the caller's dict must not be mutated
            # across retries.
            dataset_dict = copy.deepcopy(dataset)
            video_annos = dataset_dict.pop("annotations", None)
            file_names = dataset_dict.pop("file_names", None)
            """
            Args:
                dataset_dict (dict): Metadata of one video, in YTVIS Dataset format.

            Returns:
                dict: a format that builtin models in detectron2 accept
            """

            video_length = dataset_dict["length"]
            if self.is_train:
                # Pick a random reference frame, then sample the remaining
                # sampling_frame_num - 1 frames from a window of
                # +/- sampling_frame_range around it (reference excluded).
                ref_frame = random.randrange(video_length)

                start_idx = max(0, ref_frame - self.sampling_frame_range)
                end_idx = min(video_length, ref_frame + self.sampling_frame_range + 1)
                available_frames = list(range(start_idx, ref_frame)) + list(range(ref_frame + 1, end_idx))
                population_size = len(set(available_frames))
                # Sample with replacement only when the window is too small to
                # supply enough distinct frames.
                if population_size > self.sampling_frame_num - 1:
                    replace = False
                else:
                    replace = True
                # NOTE(review): this rebuilds the same candidate list as
                # `available_frames` above.
                selected_idx = np.random.choice(
                    np.array(list(range(start_idx, ref_frame)) + list(range(ref_frame + 1, end_idx))),
                    self.sampling_frame_num - 1, replace=replace
                )
                selected_idx = selected_idx.tolist() + [ref_frame]
                selected_idx = sorted(selected_idx)
                if self.sampling_frame_shuffle:
                    random.shuffle(selected_idx)
            else:
                # Inference: use every frame of the video, in order.
                selected_idx = range(video_length)

            if self.is_train:
                # Map each object id appearing in any sampled frame to a
                # contiguous slot index, so all frames share one instance
                # ordering. (Iteration over the set fixes an arbitrary but
                # consistent order for this clip.)
                _ids = set()
                for frame_idx in selected_idx:
                    _ids.update([anno["id"] for anno in video_annos[frame_idx]])
                ids = dict()
                for i, _id in enumerate(_ids):
                    ids[_id] = i

            dataset_dict["video_len"] = len(video_annos)
            dataset_dict["frame_idx"] = list(selected_idx)
            dataset_dict["image"] = []
            dataset_dict["instances"] = []
            dataset_dict["gt_masks_merge"] = []
            dataset_dict["file_names"] = []
            # Per-frame flag: 1 if the frame has any foreground pixel.
            valid = []
            for frame_idx in selected_idx:
                dataset_dict["file_names"].append(file_names[frame_idx])

                image = utils.read_image(file_names[frame_idx], format=self.image_format)
                # NOTE(review): height/width are overwritten each iteration,
                # so they end up describing the last frame — assumes all
                # frames share one size; confirm.
                height, width = image.shape[:2]
                dataset_dict["height"] = height
                dataset_dict["width"] = width

                # Apply augmentations; keep the resulting transforms so the
                # annotations below can be transformed identically.
                aug_input = T.AugInput(image)
                transforms = self.augmentations(aug_input)
                image = aug_input.image

                image_shape = image.shape[:2]

                # HWC uint8 -> CHW tensor (no normalization here).
                dataset_dict["image"].append(torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))))

                # Inference / missing annotations: images only, skip GT work.
                if (video_annos is None) or (not self.is_train):
                    continue

                # Deep-copy annotations because transform_instance_annotations
                # mutates its input in place.
                _frame_annos = []
                for anno in video_annos[frame_idx]:
                    _anno = {}
                    for k, v in anno.items():
                        _anno[k] = copy.deepcopy(v)
                    _frame_annos.append(_anno)

                annos = [
                    utils.transform_instance_annotations(obj, transforms, image_shape)
                    for obj in _frame_annos
                    if obj.get("iscrowd", 0) == 0
                ]
                # Pad with dummy annotations so this frame has one slot per
                # object id seen anywhere in the clip; objects absent in this
                # frame keep the dummy (id -1).
                sorted_annos = [_get_dummy_anno() for _ in range(len(ids))]

                for _anno in annos:
                    idx = ids[_anno["id"]]
                    sorted_annos[idx] = _anno
                _gt_ids = [_anno["id"] for _anno in sorted_annos]

                instances = utils.annotations_to_instances(sorted_annos, image_shape, mask_format="bitmask")
                # Cross-dataset training: remap class ids; unmapped classes
                # become -1 and are dropped by filter_empty_instances below.
                if not self.is_tgt:
                    instances.gt_classes = torch.tensor(
                        [self.src2tgt[c] if c in self.src2tgt else -1 for c in instances.gt_classes.tolist()]
                    )
                instances.gt_ids = torch.tensor(_gt_ids)

                if not instances.has("gt_masks"):
                    instances.gt_masks = BitMasks(torch.empty((0, *image_shape)))
                # Boxes are derived from masks (annotations may carry dummy
                # boxes), keeping boxes consistent with the transformed masks.
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()

                # Marks empty instances with gt_ids = -1; does not drop rows.
                instances = filter_empty_instances(instances)
                merged_masks = self._merge_masks(instances.gt_masks.tensor)
                dataset_dict['gt_masks_merge'].append(merged_masks)
                if (merged_masks > 0).any():
                    valid.append(1)
                else:
                    valid.append(0)

                dataset_dict["instances"].append(instances)

            # Accept the sample if any frame has foreground, or always at
            # inference time (valid stays empty there).
            if torch.any(torch.tensor(valid) == 1) or not self.is_train:
                instance_check = True
            else:
                instance_check = False

        # Tokenize the referring expression, padded/truncated to max_tokens.
        sentence_raw = dataset_dict['sentence']
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens

        input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)

        input_ids = input_ids[:self.max_tokens]
        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)

        dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0)
        dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0)
        dataset_dict['dataset_name'] = 'mevis'
        return dataset_dict
|
|
|