Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers. The channel dimension (dim 1) is split into
    ``groups`` groups of ``C // groups`` channels each, and the channels are
    interleaved so that consecutive output channels come from different
    groups.

    Args:
        x (Tensor): The input tensor of shape ``(N, C, *)``. ``*`` is any
            number (including zero) of trailing dimensions, e.g. ``(H, W)``
            for 2-D feature maps.
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation, with the
        same shape as ``x``.

    Raises:
        AssertionError: If ``C`` is not divisible by ``groups``.
        ValueError: If ``x`` has fewer than 2 dimensions.
    """
    # Everything after the batch and channel dims is carried through
    # unchanged, so the shuffle works for any number of spatial dims.
    batch_size, num_channels, *spatial = x.size()
    # Explicit raise (not an ``assert`` statement) so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if num_channels % groups != 0:
        raise AssertionError('num_channels should be divisible by groups')
    channels_per_group = num_channels // groups
    # (N, C, *) -> (N, groups, C // groups, *)
    x = x.view(batch_size, groups, channels_per_group, *spatial)
    # Swap the group and per-group axes to interleave channels; transpose
    # yields a non-contiguous tensor, so make it contiguous before ``view``.
    x = torch.transpose(x, 1, 2).contiguous()
    # Flatten back to (N, C, *).
    x = x.view(batch_size, num_channels, *spatial)
    return x