| import torch.nn.functional as F | |
| from torch import Tensor | |
| from .module import Module | |
# Public names exported by ``from <this module> import *``.
__all__ = ["ChannelShuffle"]
class ChannelShuffle(Module):
    r"""Divides and rearranges the channels in a tensor.

    This operation divides the channels of a tensor of shape :math:`(N, C, *)`
    into :math:`g` groups, viewing it as :math:`(N, g, \frac{C}{g}, *)`. It then
    swaps the two grouped dimensions to get :math:`(N, \frac{C}{g}, g, *)` and
    flattens them back into a single channel dimension. The output therefore
    keeps the input's shape, with channels interleaved across the groups.

    Args:
        groups (int): number of groups to divide channels in. Must be positive,
            and the channel dimension of the input must be divisible by it.

    Raises:
        ValueError: if ``groups`` is not positive.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C` is divisible by ``groups``
        - Output: :math:`(N, C, *)`, same shape as the input

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[13., 14.],
                  [15., 16.]]]])
    """

    # ``groups`` is declared a constant so TorchScript can treat it as such.
    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Fail fast at construction time. Otherwise an invalid value would
        # only be reported later, on the first forward() call.
        if groups <= 0:
            raise ValueError(
                f"ChannelShuffle: groups must be a positive integer, got {groups}"
            )
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        # Delegates the reshape/transpose/flatten to the functional form.
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. ``ChannelShuffle(groups=2)``.
        return f"groups={self.groups}"