"""
Wrappers around some nn functions, mainly to support empty tensors.

Ideally, support for empty tensors should be added to those functions
in PyTorch directly. These wrappers can be removed once
https://github.com/pytorch/pytorch/issues/12013 is implemented.
"""

import warnings
from typing import List, Optional

import torch
from torch.nn import functional as F

from detectron2.utils.env import TORCH_VERSION


def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensors, so the output can trace to the inputs.
    In scripting or eager mode, `x` should be a list of ints.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"
        # avoid `torch.as_tensor` here: in tracing it would record the values
        # as constants instead of keeping them connected to the inputs
        ret = torch.stack(x)
        if ret.device != device:
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)
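
# Illustrative usage sketch (not part of the original module): in eager mode a
# plain list of ints is fine; under tracing, pass 0-dim Tensors so the result
# stays tied to the inputs instead of being baked in as constants.
#
#   sizes = shapes_to_tensor([3, 4])          # tensor([3, 4])
#   h, w = torch.tensor(3), torch.tensor(4)
#   traced_sizes = shapes_to_tensor([h, w])   # traceable back to h and w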


def check_if_dynamo_compiling():
    """
    Return whether this code is currently being traced by dynamo (`torch.compile`).
    """
    if TORCH_VERSION >= (1, 14):
        from torch._dynamo import is_compiling

        return is_compiling()
    else:
        return False


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of `torch.cat` that avoids a copy if there is only a single
    element in the list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
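
# Illustrative only: with a single-element list, `cat` returns the tensor
# itself, so no memory is copied.
#
#   a = torch.arange(4)
#   assert cat([a]) is a
#   assert cat([a, a]).shape == (8,)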


def empty_input_loss_func_wrapper(loss_func):
    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        if target.numel() == 0 and reduction == "mean":
            # the mean of an empty tensor is nan; return a zero that is still
            # connected to the graph so gradients remain defined
            return input.sum() * 0.0
        return loss_func(input, target, reduction=reduction, **kwargs)

    return wrapped_loss_func


cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)
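
# Illustrative only: with an empty batch, the wrapped loss returns a
# graph-connected zero instead of nan.
#
#   logits = torch.zeros((0, 5), requires_grad=True)
#   labels = torch.zeros((0,), dtype=torch.long)
#   loss = cross_entropy(logits, labels)   # tensor(0., grad_fn=...), not nan
#   loss.backward()                        # gradients are defined (all empty/zero)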


class _NewEmptyTensorOp(torch.autograd.Function):
    """
    Autograd-aware op that creates an empty tensor of a new shape from `x`;
    the backward pass returns an empty gradient of the original shape.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return _NewEmptyTensorOp.apply(grad, shape), None
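
# Illustrative only: produce an empty output that keeps a gradient path back
# to the input.
#
#   x = torch.randn(0, 3, requires_grad=True)
#   y = _NewEmptyTensorOp.apply(x, (0, 7))   # shape (0, 7), still in the graph
#   y.sum().backward()                       # x.grad has shape (0, 3)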


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        The norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # (https://github.com/pytorch/pytorch/issues/40507), so the empty-input
        # check below is skipped when scripting: torchscript is only used in
        # evaluation mode, and `Conv2d` in recent PyTorch versions already
        # supports empty inputs.
        if not torch.jit.is_scripting():
            # dynamo does not support context managers yet
            is_dynamo_compiling = check_if_dynamo_compiling()
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
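
# Illustrative usage sketch (GroupNorm and relu are arbitrary choices here,
# not defaults of this class):
#
#   conv = Conv2d(
#       3, 16, kernel_size=3, padding=1,
#       norm=torch.nn.GroupNorm(4, 16),
#       activation=F.relu,
#   )
#   y = conv(torch.randn(2, 3, 32, 32))      # shape (2, 16, 32, 32)
#   empty = conv(torch.zeros(0, 3, 32, 32))  # empty inputs are supported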


ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear


def nonzero_tuple(x):
    """
    An `as_tuple=True` version of `torch.nonzero` that works in torchscript,
    because of https://github.com/pytorch/pytorch/issues/38718
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
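
# Illustrative only: returns one index tensor per dimension, like
# `torch.nonzero(x, as_tuple=True)`.
#
#   m = torch.tensor([[1, 0], [0, 2]])
#   rows, cols = nonzero_tuple(m)   # rows = tensor([0, 1]), cols = tensor([0, 1])
#   m[rows, cols]                   # tensor([1, 2])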


@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing-friendly way to cast a tensor to another tensor's device. During
    tracing the device would be treated as a constant; scripting the cast as a
    whole works around this issue.
    """
    return src.to(dst.device)
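
# Illustrative only: under tracing, a plain `.to(dst.device)` would bake the
# device in as a constant; this scripted helper keeps it dynamic.
#
#   src = torch.arange(4)
#   dst = torch.zeros(1)            # e.g. a tensor already on the target device
#   out = move_device_like(src, dst)
#   assert out.device == dst.device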