from typing import List, Union
import torch
import torchvision
class NestedTensor(object):
"""Define a NestedTensor class
Args:
tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W]
mask (Union[torch.Tensor, str]): mask with shape [batch, H, W] or [H, W]. If mask
is 'auto', it will be generated automatically by summing the tensor along
the channel dimension. Mask is used to indicate the padding area.
"""
def __init__(
self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto"
) -> None:
self.tensors = tensors
self.mask = mask
if mask == "auto":
self.mask = torch.zeros_like(tensors).to(tensors.device)
if self.mask.dim() == 3:
self.mask = self.mask.sum(0).to(bool)
elif self.mask.dim() == 4:
self.mask = self.mask.sum(1).to(bool)
else:
raise ValueError(
"tensors dim must be 3 or 4 but {}({})".format(
self.tensors.dim(), self.tensors.shape
)
)
    def imgsize(self) -> List[torch.Tensor]:
        """Get the unpadded image size of each tensor in the batch.

        Returns:
            list[torch.Tensor]: list of tensors with shape [2], each holding [H, W].
        """
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            # Valid (non-padded) positions are False in the mask, so counting
            # them along each axis recovers the original height and width.
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res
    def to(self, device: torch.device):
        """Move the tensors and mask to the given device.

        Args:
            device (torch.device): device to move to
        Returns:
            NestedTensor: the moved NestedTensor
        """
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)
    def to_img_list_single(
        self, tensor: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Remove the padding from a single image.

        Args:
            tensor (torch.Tensor): tensor with shape [C, H, W]
            mask (torch.Tensor): mask with shape [H, W]
        Returns:
            torch.Tensor: tensor with shape [C, maxH, maxW]
        """
        assert tensor.dim() == 3, "dim of tensor should be 3 but got {}".format(
            tensor.dim()
        )
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img
    def to_img_list(self) -> List[torch.Tensor]:
        """Remove the padding and convert the batch into a list of images.

        Returns:
            list[torch.Tensor]: list of tensors with shape [C, maxH, maxW]
            (a single tensor if the NestedTensor holds an unbatched [C, H, W] image).
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        else:
            res = []
            for i in range(self.tensors.shape[0]):
                tensor_i = self.tensors[i]
                mask_i = self.mask[i]
                res.append(self.to_img_list_single(tensor_i, mask_i))
            return res
@property
def device(self):
return self.tensors.device
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
@property
def shape(self):
return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    # Element-wise maximum across a list of shape lists, e.g. the largest
    # C, H and W among a batch of images. Copy the first entry so the
    # caller's list is not modified in place.
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes
def nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor], fixed_img_size=None
):
    """Pad a list of [C, H, W] tensors to a common size and return a NestedTensor.

    Args:
        tensor_list (list[torch.Tensor]): images with shape [C, H, W].
        fixed_img_size (int or (int, int), optional): pad every image to this
            fixed (H, W) instead of the maximum size in the batch.
    """
    if fixed_img_size is not None:
        if isinstance(fixed_img_size, (list, tuple)):
            assert (
                len(fixed_img_size) == 2
            ), "image size should be a tuple or list with two elements"
        elif isinstance(fixed_img_size, int):
            fixed_img_size = [fixed_img_size, fixed_img_size]
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX;
            # call _onnx_nested_tensor_from_tensor_list() instead.
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        if fixed_img_size is not None:
            c, orig_h, orig_w = max_size
            assert orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1], (
                f"the fixed output size {tuple(fixed_img_size)} must be at least as "
                f"large as the largest input image ({orig_h}, {orig_w})"
            )
            max_size = [c, fixed_img_size[0], fixed_img_size[1]]
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        # Copy each image into the top-left corner of its padded slot and mark
        # the valid (non-padded) region as False in the mask.
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(
    tensor_list: List[torch.Tensor],
) -> NestedTensor:
    """ONNX-traceable variant of nested_tensor_from_tensor_list."""
    max_size = []
    for i in range(tensor_list[0].dim()):
        # During tracing the shape entries are 0-dim tensors, so the maximum is
        # computed with torch ops instead of Python built-ins.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # Pad every image (and its mask) up to the common maximum size.
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs starting from the last dimension.
        padded_img = torch.nn.functional.pad(
            img, (0, padding[2], 0, padding[1], 0, padding[0])
        )
        padded_imgs.append(padded_img)
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(
            m, (0, padding[2], 0, padding[1]), "constant", 1
        )
        padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
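

# A minimal end-to-end sketch, not part of the original file: it pads two
# differently sized images into a single batch with
# nested_tensor_from_tensor_list and reads the result back. The image sizes
# below are arbitrary examples.
if __name__ == "__main__":
    images = [torch.rand(3, 480, 640), torch.rand(3, 600, 400)]
    nt = nested_tensor_from_tensor_list(images)
    batch, mask = nt.decompose()
    print(batch.shape)  # torch.Size([2, 3, 600, 640]) - padded to the batch max H, W
    print(mask.shape)   # torch.Size([2, 600, 640]), True marks padded pixels
    print([t.shape for t in nt.to_img_list()])  # original [C, H, W] shapes

    # Optional fixed-size padding (must be >= the largest input image).
    nt_fixed = nested_tensor_from_tensor_list(images, fixed_img_size=800)
    print(nt_fixed.tensors.shape)  # torch.Size([2, 3, 800, 800])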