# Source: FBAGSTM / STM32 AI Experimentation Hub, revision 747451d
# ---------------------------------------------------------------------------------------------
# Copyright (c) 2025 STMicroelectronics.
# All rights reserved.
#
# Copyright (c) 2018-2024 Oleg Sémery
#
# This software is licensed under terms that can be found in the LICENSE file in
# the root directory of this software component.
# If no LICENSE file comes with this software, it is provided AS-IS.
# Taken from: https://github.com/osmr/imgclsmob
# ---------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
def channel_shuffle(x, groups):
    """
    Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural
    Network for Mobile Devices,' https://arxiv.org/abs/1707.01083.

    The channel axis (dim 1) is viewed as a (groups, channels // groups) grid and
    transposed, so output channel ``i * groups + g`` is input channel
    ``g * (channels // groups) + i``.

    Parameters:
    ----------
    x : Tensor
        Input tensor of shape (batch, channels, *spatial). Typically a 4D
        (batch, channels, height, width) feature map, but any number of trailing
        dimensions (including none) is accepted.
    groups : int
        Number of groups. `channels` must be divisible by `groups`.

    Returns:
    -------
    Tensor
        Resulted tensor, with the same shape as `x`.

    Raises:
    ------
    ValueError
        If `x` has fewer than 2 dimensions.
    RuntimeError
        Raised by `view` (for non-empty inputs) if `channels` is not divisible by `groups`.
    """
    if x.dim() < 2:
        raise ValueError("channel_shuffle expects a tensor with at least 2 dimensions")
    size = x.size()
    batch, channels = size[0], size[1]
    # Trailing (spatial) dims are carried through unchanged; empty for a 2D input.
    spatial = list(size[2:])
    # Divisibility is deliberately not checked here (the upstream code left that assert
    # commented out). ChannelShuffle validates it at construction time, and for non-empty
    # inputs a mismatch makes the first `view` fail because the element count differs.
    channels_per_group = channels // groups
    x = x.view([batch, groups, channels_per_group] + spatial)
    # Swap the (groups, channels_per_group) axes; `contiguous` is required before `view`.
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view([batch, channels] + spatial)
    return x
class ChannelShuffle(nn.Module):
    """
    Channel shuffle layer: an `nn.Module` wrapper around the `channel_shuffle` function
    that remembers the number of groups, so the operation composes like any other module.

    Parameters:
    ----------
    channels : int
        Number of channels. Only used at construction time to check that the channels
        split evenly into `groups`; it is not stored on the module.
    groups : int
        Number of groups.

    Raises:
    ------
    ValueError
        If `channels` is not divisible by `groups`.
    """

    def __init__(self, channels, groups):
        super().__init__()
        evenly_split = channels % groups == 0
        if not evenly_split:
            raise ValueError("channels must be divisible by groups")
        self.groups = groups

    def forward(self, x):
        # Delegate to the module-level functional implementation.
        return channel_shuffle(x, self.groups)

    def __repr__(self):
        return f"{type(self).__name__}(groups={self.groups})"