# Copyright (c) Facebook, Inc. and its affiliates. from __future__ import division from typing import Any, List, Tuple import torch from torch import device from detectron2.layers.wrappers import shapes_to_tensor class ImageList(object): """ Structure that holds a list of images (of possibly varying sizes) as a single tensor. This works by padding the images to the same size. The original sizes of each image is stored in `image_sizes`. Attributes: image_sizes (list[tuple[int, int]]): each tuple is (h, w). During tracing, it becomes list[Tensor] instead. """ def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]], padding_mask: torch.Tensor = None): """ Arguments: tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1 image_sizes (list[tuple[int, int]]): Each tuple is (h, w). It can be smaller than (H, W) due to padding. """ self.tensor = tensor self.image_sizes = image_sizes self.padding_mask = padding_mask def __len__(self) -> int: return len(self.image_sizes) def __getitem__(self, idx) -> torch.Tensor: """ Access the individual image in its original size. Args: idx: int or slice Returns: Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1 """ size = self.image_sizes[idx] return self.tensor[idx, ..., : size[0], : size[1]] @torch.jit.unused def to(self, *args: Any, **kwargs: Any) -> "ImageList": cast_tensor = self.tensor.to(*args, **kwargs) return ImageList( cast_tensor, self.image_sizes, padding_mask=(self.padding_mask.to(*args, **kwargs) if self.padding_mask is not None else None), ) @property def device(self) -> device: return self.tensor.device @staticmethod def from_tensors(tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0) -> "ImageList": """ Args: tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded to the same shape with `pad_value`. size_divisibility (int): If `size_divisibility > 0`, add padding to ensure the common height and width is divisible by `size_divisibility`. This depends on the model and many models need a divisibility of 32. pad_value (float): value to pad Returns: an `ImageList`. """ assert len(tensors) > 0 assert isinstance(tensors, (tuple, list)) for t in tensors: assert isinstance(t, torch.Tensor), type(t) assert t.shape[:-2] == tensors[0].shape[:-2], t.shape image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors] image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes] max_size = torch.stack(image_sizes_tensor).max(0).values if size_divisibility > 1: stride = size_divisibility # the last two dims are H,W, both subject to divisibility requirement max_size = (max_size + (stride - 1)).div(stride, rounding_mode="floor") * stride # handle weirdness of scripting and tracing ... if torch.jit.is_scripting(): max_size: List[int] = max_size.to(dtype=torch.long).tolist() else: if torch.jit.is_tracing(): image_sizes = image_sizes_tensor # max_size can be a tensor in tracing mode, therefore convert to list batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size) batched_imgs = tensors[0].new_full(batch_shape, pad_value) batched_masks = tensors[0].new_full([len(tensors)] + list(max_size), 1.0).bool() for img, pad_img, m in zip(tensors, batched_imgs, batched_masks): pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) m[: img.shape[-2], : img.shape[-1]] = False return ImageList(batched_imgs.contiguous(), image_sizes, batched_masks)