Spaces:
Runtime error
Runtime error
| from typing import List, Union | |
| import torch | |
| import torchvision | |
class NestedTensor(object):
    """Bundle an image tensor with the mask describing its padded area.

    In the mask, ``True`` marks padded positions and ``False`` marks valid
    pixels.

    Args:
        tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W].
        mask (Union[torch.Tensor, str]): Mask with shape [batch, H, W] or
            [H, W]. If mask is the string 'auto', an all-False mask (i.e. no
            padding) covering the spatial dimensions of ``tensors`` is created.

    Raises:
        ValueError: If ``mask`` is 'auto' and ``tensors`` is neither 3- nor
            4-dimensional.
    """

    def __init__(
        self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto"
    ) -> None:
        self.tensors = tensors
        self.mask = mask
        # Check the type before comparing: ``tensor == "auto"`` would compare a
        # tensor against a string, which is not a dependable sentinel test.
        if isinstance(mask, str) and mask == "auto":
            if tensors.dim() not in (3, 4):
                raise ValueError(
                    "tensors dim must be 3 or 4 but {}({})".format(
                        self.tensors.dim(), self.tensors.shape
                    )
                )
            # Drop the channel axis: [C, H, W] -> [H, W] and
            # [batch, C, H, W] -> [batch, H, W]. Nothing is padded, so the
            # mask is all False.
            spatial_shape = tuple(tensors.shape[:-3]) + tuple(tensors.shape[-2:])
            self.mask = torch.zeros(
                spatial_shape, dtype=torch.bool, device=tensors.device
            )

    def imgsize(self) -> List[torch.Tensor]:
        """Get the unpadded size of every image.

        Returns:
            list[torch.Tensor]: One tensor of shape [2] holding [H, W] per
                image. An unbatched [C, H, W] input yields a one-element list.
        """
        if self.tensors.dim() == 3:
            # Unbatched input: the mask itself is the single [H, W] mask.
            masks = [self.mask]
        else:
            masks = [self.mask[i] for i in range(self.tensors.shape[0])]

        res = []
        for mask in masks:
            valid = ~mask
            # Valid rows per column -> height; valid columns per row -> width.
            maxH = valid.sum(0).max()
            maxW = valid.sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device: torch.device):
        """Move tensors and mask to the given device.

        Args:
            device (torch.device): Device to move to.

        Returns:
            NestedTensor: A new NestedTensor on ``device``. A ``None`` mask
                stays ``None``.
        """
        cast_tensor = self.tensors.to(device)
        cast_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(
        self, tensor: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Remove the padding from one image.

        Args:
            tensor (torch.Tensor): Tensor with shape [C, H, W].
            mask (torch.Tensor): Mask with shape [H, W].

        Returns:
            torch.Tensor: Tensor with shape [C, maxH, maxW].
        """
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(
            tensor.dim()
        )
        valid = ~mask
        maxH = valid.sum(0).max()
        maxW = valid.sum(1).max()
        return tensor[:, :maxH, :maxW]

    def to_img_list(self) -> List[torch.Tensor]:
        """Remove the padding and convert to an image list.

        Returns:
            list[torch.Tensor]: Tensors with shape [C, maxH, maxW], one per
                batch entry. For an unbatched [C, H, W] input the single
                cropped tensor is returned directly rather than in a list.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        return [
            self.to_img_list_single(self.tensors[i], self.mask[i])
            for i in range(self.tensors.shape[0])
        ]

    def device(self):
        """Return the device that holds ``tensors``."""
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    def shape(self):
        """Return the shapes of ``tensors`` and ``mask`` as a dict.

        The mask entry is ``None`` when no mask is stored.
        """
        mask_shape = None if self.mask is None else self.mask.shape
        return {"tensors.shape": self.tensors.shape, "mask.shape": mask_shape}
| def _max_by_axis(the_list): | |
| # type: (List[List[int]]) -> List[int] | |
| maxes = the_list[0] | |
| for sublist in the_list[1:]: | |
| for index, item in enumerate(sublist): | |
| maxes[index] = max(maxes[index], item) | |
| return maxes | |
def nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor], fixed_img_size=None
):
    """Zero-pad a list of [C, H, W] images to one shape and batch them.

    Args:
        tensor_list: Images to batch; the first one must be 3-dimensional.
        fixed_img_size: Optional output size. An int is used for both height
            and width; a list/tuple must contain exactly [H, W]. It is not
            forwarded to the tracing (ONNX) code path.

    Returns:
        NestedTensor: Padded batch of shape [batch, C, H, W] together with a
            boolean mask of shape [batch, H, W] that is True on padding.

    Raises:
        ValueError: If the first tensor is not 3-dimensional.
        AssertionError: If ``fixed_img_size`` is a list/tuple without exactly
            two entries, or is smaller than the largest input image.
    """
    # Normalise the optional target size into a two-element [H, W] sequence.
    if fixed_img_size is not None:
        if isinstance(fixed_img_size, (list, tuple)):
            assert (
                len(fixed_img_size) == 2
            ), "image size should be a tuple or list with two elements"
        elif isinstance(fixed_img_size, int):
            fixed_img_size = [fixed_img_size, fixed_img_size]

    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX
        # call _onnx_nested_tensor_from_tensor_list() instead
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # Largest extent along every axis across all images.
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    if fixed_img_size is not None:
        channels, orig_h, orig_w = max_size
        assert (
            orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1]
        ), f"{orig_h} {orig_w} the fixed output image size should be larger than original image"
        max_size = [channels, fixed_img_size[0], fixed_img_size[1]]

    num_imgs = len(tensor_list)
    batch_shape = [num_imgs] + max_size
    _, height, width = max_size

    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    # Start fully padded, then clear the region each image actually covers.
    mask = torch.ones((num_imgs, height, width), dtype=torch.bool, device=first.device)
    for idx, img in enumerate(tensor_list):
        tensor[idx, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        mask[idx, : img.shape[1], : img.shape[2]] = False

    return NestedTensor(tensor, mask)
def _onnx_nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor],
) -> NestedTensor:
    """Batch images by padding with tensor ops, for use while tracing.

    Called by ``nested_tensor_from_tensor_list`` when
    ``torchvision._is_tracing()`` is true.

    Args:
        tensor_list: Images to batch. The padding below indexes three axes,
            so [C, H, W] inputs are expected.

    Returns:
        NestedTensor: Stacked padded images and a stacked boolean mask that is
            True on the padded area.
    """
    # NOTE(review): torch.stack over ``img.shape[axis]`` presumably relies on
    # shape entries being tensors during tracing -- confirm before reuse
    # outside a trace.
    num_axes = tensor_list[0].dim()
    max_size = tuple(
        torch.max(
            torch.stack([img.shape[axis] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        for axis in range(num_axes)
    )

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount each axis must grow to reach the common size.
        padding = [(largest - own) for largest, own in zip(max_size, tuple(img.shape))]

        # F.pad takes (left, right) pairs starting from the last axis.
        img_pad = (0, padding[2], 0, padding[1], 0, padding[0])
        padded_imgs.append(torch.nn.functional.pad(img, img_pad))

        # Zeros over the real image, padded with ones, then cast to bool.
        valid = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(
            valid, (0, padding[2], 0, padding[1]), "constant", 1
        )
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)
    return NestedTensor(tensor, mask=mask)