| |
| |
| """ |
| Misc functions, including distributed helpers. |
| |
| Mostly copy-paste from torchvision references. |
| """ |
| from typing import List, Optional |
|
|
| import torch |
| import torch.distributed as dist |
| import torchvision |
| from torch import Tensor |
| import warnings |
| import torch.nn.functional as F |
| import math |
|
|
def inverse_sigmoid(x, eps=1e-3):
    """Numerically safe inverse of ``torch.sigmoid`` (the logit).

    The input is first clipped to ``[0, 1]``.  Both ``p`` and ``1 - p`` are then
    floored at ``eps`` so the log ratio stays finite at the boundaries.

    Args:
        x: tensor of probabilities.
        eps: lower bound applied to the numerator and denominator.

    Returns:
        ``log(p / (1 - p))`` with the clipping described above.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return torch.log(numerator / denominator)
|
|
| def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
| |
| |
| def norm_cdf(x): |
| |
| return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
| if (mean < a - 2 * std) or (mean > b + 2 * std): |
| warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
| "The distribution of values may be incorrect.", |
| stacklevel=2) |
|
|
| with torch.no_grad(): |
| |
| |
| |
| l = norm_cdf((a - mean) / std) |
| u = norm_cdf((b - mean) / std) |
|
|
| |
| |
| tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
| |
| |
| tensor.erfinv_() |
|
|
| |
| tensor.mul_(std * math.sqrt(2.)) |
| tensor.add_(mean) |
|
|
| |
| tensor.clamp_(min=a, max=b) |
| return tensor |
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`.  Samples come from inverse-CDF sampling (there is no
    rejection / redraw loop) and are finally clamped to :math:`[a, b]`.  No
    gradient is recorded.  A warning is issued when ``mean`` lies more than
    two ``std`` outside :math:`[a, b]`, since the resulting values may then be
    inaccurate.  The method works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor` to overwrite.
        mean: the mean of the normal distribution.
        std: the standard deviation of the normal distribution.
        a: the minimum cutoff value.
        b: the maximum cutoff value.

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize ``input`` with ``F.interpolate``, optionally warning about misalignment.

    When ``warning`` is set, ``size`` is given, ``align_corners`` is truthy and
    ``input`` is 4-D, a warning is emitted if all of the following hold:
    - the resize upsamples height or width;
    - every size involved is > 1;
    - ``(out - 1)`` is not a multiple of ``(in - 1)`` in both spatial dims.

    Args:
        input: tensor to resize.  The alignment check reads its last two dims
            as (height, width) and is skipped for non-4-D input.
        size: target output size (int, sequence, or ``torch.Size``).  A
            ``torch.Size`` is converted to a plain tuple of ints before calling
            ``F.interpolate``.
        scale_factor: forwarded to ``F.interpolate``.
        mode: interpolation mode, forwarded to ``F.interpolate``.
        align_corners: forwarded to ``F.interpolate``.
        warning: set to False to skip the alignment check.

    Returns:
        The tensor returned by ``F.interpolate``.
    """
    if warning and size is not None and align_corners and input.dim() == 4:
        input_h, input_w = (int(s) for s in input.shape[2:])
        if isinstance(size, int):
            # A scalar size applies to every spatial dim in F.interpolate.
            output_h = output_w = size
        else:
            output_h, output_w = (int(s) for s in size)
        # Compare each output dim with the matching input dim (the width was
        # previously compared against the output height by mistake).
        if output_h > input_h or output_w > input_w:
            # NOTE(review): warns only when BOTH dims are misaligned; `or` may be
            # what was intended — confirm.
            if ((output_h > 1 and output_w > 1 and input_h > 1
                 and input_w > 1) and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)
|
|
| def _max_by_axis(the_list): |
| |
| maxes = the_list[0] |
| for sublist in the_list[1:]: |
| for index, item in enumerate(sublist): |
| maxes[index] = max(maxes[index], item) |
| return maxes |
|
|
|
|
class NestedTensor:
    """Bundle a batched tensor with an optional companion mask.

    ``nested_tensor_from_tensor_list`` builds the mask so that it is True on
    padded positions.  The class itself places no constraint on the mask.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors  # batched data
        self.mask = mask        # optional mask tensor, or None

    def to(self, device):
        """Return a new NestedTensor with its members moved to ``device``.

        A missing (``None``) mask stays ``None``.
        """
        moved_tensors = self.tensors.to(device)
        current_mask = self.mask
        moved_mask = None if current_mask is None else current_mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Show only the wrapped tensor; the mask is not included.
        return str(self.tensors)
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Batch a list of 3-D tensors into a zero-padded ``NestedTensor``.

    Each tensor is copied into the top-left corner of a zero tensor.  That zero
    tensor is sized by the per-dimension maximum over the list.  The returned
    mask has shape ``(batch, max_dim1, max_dim2)`` and is False where data was
    copied and True on padding.

    While ``torchvision._is_tracing()`` is true, the work is delegated to
    ``_onnx_nested_tensor_from_tensor_list``.

    Args:
        tensor_list: non-empty list of tensors.  The first one must be 3-D.

    Returns:
        NestedTensor of the padded batch and its mask.

    Raises:
        ValueError: if the first tensor is not 3-D.
    """
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # Tracing uses a separate, pad-based implementation.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    batch_size, _, max_h, max_w = batch_shape
    device = first.device
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=device)
    mask = torch.ones((batch_size, max_h, max_w), dtype=torch.bool, device=device)
    for img, pad_img, pad_mask in zip(tensor_list, tensor, mask):
        channels, height, width = img.shape[0], img.shape[1], img.shape[2]
        pad_img[:channels, :height, :width].copy_(img)
        # Positions covered by real data are un-masked.
        pad_mask[:height, :width] = False
    return NestedTensor(tensor, mask)
|
|
|
|
| |
| |
# `torch.jit.unused` tells TorchScript not to compile this function's body.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad-based variant of ``nested_tensor_from_tensor_list``.

    This is the path taken from ``nested_tensor_from_tensor_list`` while
    ``torchvision._is_tracing()`` is true.

    Instead of slicing into a preallocated batch, it builds the batch with
    ``F.pad`` and ``torch.stack``:
    - every image is padded at the end of each dim up to the per-dimension
      maximum, and the padded images are stacked;
    - an int mask of zeros, shaped like one channel of the image, is padded
      with 1s, cast to bool and stacked.

    NOTE(review): the avoidance of in-place slice copies/assignments is
    presumably because those do not export to ONNX — confirm.

    Args:
        tensor_list: images to batch.  ``padding[0..2]`` is indexed below, so
            each image is assumed to be 3-D, e.g. (C, H, W) — TODO confirm.

    Returns:
        NestedTensor whose ``tensors`` holds the stacked padded images.  Its
        ``mask`` is a bool tensor that is True on padded positions and False
        where image data is present.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        # NOTE(review): stacking shape entries needs them to be tensors (as under
        # tracing); the float32 round-trip around torch.max is presumably an
        # exporter workaround — confirm.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed at the end of each dim to reach max_size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (before, after) pairs starting from the LAST dim, so this
        # pads dim 2, then dim 1, then dim 0 — all at the end only.
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the image's last two dims; the padded border is filled with 1,
        # so after the bool cast True marks padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in and a process group exists.

    ``dist.is_initialized`` is only consulted when ``dist.is_available()`` is
    true, thanks to short-circuit evaluation.
    """
    return dist.is_available() and dist.is_initialized()
|
|