| |
| import torch |
|
|
|
|
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers. The channel dimension is split into ``groups``
    equally sized groups and the group / within-group axes are swapped, so
    output channel ``i * groups + g`` is taken from input channel
    ``g * channels_per_group + i``.

    Args:
        x (Tensor): The input tensor of shape ``(N, C, *)``. Any number of
            trailing dimensions (e.g. ``H, W`` or ``D, H, W``, or none) is
            accepted and carried through unchanged.
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation, with the
        same shape as ``x``.

    Raises:
        AssertionError: If the number of channels is not divisible by
            ``groups``.
    """
    batch_size, num_channels = x.shape[:2]
    # Trailing dims are only reshaped around, never mixed, so any rank >= 2
    # works (the original hard-coded exactly four dims).
    spatial = x.shape[2:]
    assert (num_channels % groups == 0), ('num_channels should be '
                                          'divisible by groups')
    channels_per_group = num_channels // groups

    # (N, C, *) -> (N, groups, C // groups, *)
    x = x.view(batch_size, groups, channels_per_group, *spatial)
    # Swap the group and within-group axes; contiguous() is needed because
    # view() cannot be applied to the strided result of transpose().
    x = torch.transpose(x, 1, 2).contiguous()
    # Flatten back to (N, C, *). num_channels is passed explicitly instead of
    # -1 so that empty tensors (e.g. N == 0) do not make the size ambiguous.
    x = x.view(batch_size, num_channels, *spatial)

    return x
|
|