| from .module import Module | |
| from typing import Tuple, Union | |
| from torch import Tensor | |
| from torch.types import _size | |
| __all__ = ['Flatten', 'Unflatten'] | |
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """
    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the inclusive range of dimensions that ``forward`` flattens."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input.flatten(start_dim, end_dim)`` using the stored dims."""
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """Return the configured dims as ``'start_dim=..., end_dim=...'``."""
        return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
class Unflatten(Module):
    r"""
    Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Raises:
        TypeError: if :attr:`dim` is neither `int` nor `str`, or if
            :attr:`unflattened_size` does not have the container/element types
            required for the given kind of :attr:`dim`.

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, (2, 5, 5))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, torch.Size([2, 5, 5]))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """
    # Variadic on purpose: a named shape holds any number of (name, size)
    # pairs, not exactly one.
    NamedShape = Tuple[Tuple[str, int], ...]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        """Validate the arguments against each other and store them.

        An ``int`` dim requires a tuple/list of ints; a ``str`` dim requires a
        tuple of tuples. Any other dim type raises ``TypeError``.
        """
        super().__init__()

        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples."""
        if not isinstance(input, tuple):
            raise TypeError(
                "unflattened_size must be a tuple of tuples, "
                f"but found type {type(input).__name__}"
            )
        for idx, elem in enumerate(input):
            if not isinstance(elem, tuple):
                raise TypeError(
                    "unflattened_size must be tuple of tuples, "
                    f"but found element of type {type(elem).__name__} at pos {idx}"
                )

    def _require_tuple_int(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints."""
        if not isinstance(input, (tuple, list)):
            raise TypeError(
                f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
            )
        for idx, elem in enumerate(input):
            if not isinstance(elem, int):
                raise TypeError(
                    "unflattened_size must be tuple of ints, "
                    f"but found element of type {type(elem).__name__} at pos {idx}"
                )

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input.unflatten(dim, unflattened_size)`` using the stored arguments."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Return the configuration as ``'dim=..., unflattened_size=...'``."""
        return f'dim={self.dim}, unflattened_size={self.unflattened_size}'