from __future__ import division
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import device
from torch.nn import functional as F

from detectron2.layers.wrappers import move_device_like, shapes_to_tensor


class ImageList:
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size. The original size of
    each image is stored in `image_sizes`.

    Attributes:
        image_sizes (list[tuple[int, int]]): each tuple is (h, w).
            During tracing, it becomes list[Tensor] instead.
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
        """
        Arguments:
            tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
            image_sizes (list[tuple[int, int]]): Each tuple is (h, w). It can
                be smaller than (H, W) due to padding.
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        return len(self.image_sizes)

    def __getitem__(self, idx) -> torch.Tensor:
        """
        Access the individual image in its original size.

        Args:
            idx: int or slice

        Returns:
            Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
        """
        size = self.image_sizes[idx]
        return self.tensor[idx, ..., : size[0], : size[1]]
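
    # Usage sketch (illustrative): given
    #     il = ImageList.from_tensors([torch.rand(3, 10, 12), torch.rand(3, 8, 8)])
    # `il.tensor.shape` is (2, 3, 10, 12), while `il[1].shape` is (3, 8, 8):
    # indexing slices away the padding and returns the image at its original size.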

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "ImageList":
        cast_tensor = self.tensor.to(*args, **kwargs)
        return ImageList(cast_tensor, self.image_sizes)

    @property
    def device(self) -> device:
        return self.tensor.device

    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor],
        size_divisibility: int = 0,
        pad_value: float = 0.0,
        padding_constraints: Optional[Dict[str, int]] = None,
    ) -> "ImageList":
        """
        Args:
            tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or
                (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
                to the same shape with `pad_value`.
            size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
                the common height and width are divisible by `size_divisibility`.
                This depends on the model; many models require a divisibility of 32.
            pad_value (float): the value to pad with.
            padding_constraints (Optional[Dict[str, int]]): If given, it should be a
                dict of the form {"size_divisibility": int, "square_size": int}. Its
                `size_divisibility` overrides the argument of the same name when
                present, and a `square_size` > 0 requests square padding of that size.
        Returns:
            an `ImageList`.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)
            assert t.shape[:-2] == tensors[0].shape[:-2], t.shape

        image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors]
        image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
        max_size = torch.stack(image_sizes_tensor).max(0).values

        if padding_constraints is not None:
            square_size = padding_constraints.get("square_size", 0)
            if square_size > 0:
                # pad to a square of side `square_size`
                max_size[0] = max_size[1] = square_size
            if "size_divisibility" in padding_constraints:
                size_divisibility = padding_constraints["size_divisibility"]
        if size_divisibility > 1:
            stride = size_divisibility
            # the last two dims are (H, W); round both up to a multiple of `stride`
            max_size = (max_size + (stride - 1)).div(stride, rounding_mode="floor") * stride
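
        # For example, with stride = 32 a maximum side of 33 rounds up to
        # (33 + 31) // 32 * 32 = 64, while an exact multiple such as 64 stays 64.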

        # Scripting and tracing need different handling here: in scripting mode,
        # `max_size` must become a static List[int]; in tracing mode, keep the
        # image sizes as tensors so they are captured symbolically in the graph.
        if torch.jit.is_scripting():
            max_size: List[int] = max_size.to(dtype=torch.long).tolist()
        else:
            if torch.jit.is_tracing():
                image_sizes = image_sizes_tensor

        if len(tensors) == 1:
            # Pad the single tensor directly, avoiding the allocation and fill
            # of a full batch tensor first.
            image_size = image_sizes[0]
            padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
            batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
        else:
            # max_size can be a tensor in tracing mode, so normalize it into a list
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
            device = (
                None if torch.jit.is_scripting() else ("cpu" if torch.jit.is_tracing() else None)
            )
            batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device)
            batched_imgs = move_device_like(batched_imgs, tensors[0])
            for i, img in enumerate(tensors):
                # Copy each image into the top-left corner of its padded slot;
                # `copy_` (rather than slice assignment) keeps this compatible
                # with traced size tensors.
                batched_imgs[i, ..., : img.shape[-2], : img.shape[-1]].copy_(img)

        return ImageList(batched_imgs.contiguous(), image_sizes)
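

if __name__ == "__main__":
    # Minimal usage sketch, assuming detectron2 is installed (the wrapper
    # imports above come from it). Batch two differently sized images, padding
    # each side up to a multiple of 32.
    images = [torch.rand(3, 30, 40), torch.rand(3, 25, 50)]
    image_list = ImageList.from_tensors(images, size_divisibility=32)
    # The max raw size is (30, 50); rounding up to multiples of 32 gives (32, 64).
    assert image_list.tensor.shape == (2, 3, 32, 64)
    # Indexing recovers each image at its original, unpadded size.
    assert image_list[0].shape == (3, 30, 40)
    assert image_list[1].shape == (3, 25, 50)
    # `padding_constraints` can request fixed square padding instead:
    square_list = ImageList.from_tensors(images, padding_constraints={"square_size": 64})
    assert square_list.tensor.shape == (2, 3, 64, 64)
    print("batched:", tuple(image_list.tensor.shape), "sizes:", image_list.image_sizes)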