|
|
import math |
|
|
from functools import partial |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
class SwishImplementation(torch.autograd.Function):
    """Memory-efficient Swish: computes x * sigmoid(x) without caching the
    intermediate sigmoid in the forward pass; the backward pass recomputes it
    from the saved input instead."""

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        # Save only the input; sigmoid is recomputed in backward to save memory.
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # ctx.saved_tensors replaces the long-deprecated ctx.saved_variables.
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
|
|
|
|
class MemoryEfficientSwish(nn.Module):
    """Module wrapper around the memory-efficient Swish autograd function."""

    def forward(self, x):
        activated = SwishImplementation.apply(x)
        return activated
|
|
|
|
|
|
|
|
def drop_connect(inputs, p, training):
    """Apply drop-connect (stochastic depth) to a batch of samples.

    Each sample in the batch is zeroed out with probability ``p`` and the
    survivors are rescaled by ``1 / (1 - p)`` so the expected value is
    preserved.

    Args:
        inputs: tensor whose first dimension is the batch dimension.
        p: drop probability; must satisfy 0 <= p < 1.
        training: when False, inputs are returned unchanged.

    Returns:
        Tensor of the same shape as ``inputs``.

    Raises:
        ValueError: if ``p`` is outside [0, 1) (p == 1 would divide by zero).
    """
    if not training:
        return inputs
    if not 0 <= p < 1:
        raise ValueError(f"drop-connect probability must be in [0, 1), got {p}")
    keep_prob = 1 - p
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0;
    # one Bernoulli draw per sample, broadcast over the remaining dims.
    random_tensor = keep_prob + torch.rand(
        [inputs.shape[0], 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_mask = torch.floor(random_tensor)
    return inputs / keep_prob * binary_mask
|
|
|
|
|
|
|
|
def get_same_padding_conv2d(image_size=None):
    """Return a ``Conv2dStaticSamePadding`` constructor pre-bound to ``image_size``."""
    factory = partial(Conv2dStaticSamePadding, image_size=image_size)
    return factory
|
|
|
|
|
def get_width_and_height_from_size(x):
    """Return a (height, width) pair from an int or a 2-element list/tuple.

    An int ``x`` is interpreted as a square size and returned as ``(x, x)``;
    a list or tuple is passed through unchanged.

    Raises:
        TypeError: if ``x`` is neither an int nor a list/tuple (with a
            message naming the offending type, unlike the original bare
            ``TypeError()``).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(f"size must be an int, list, or tuple, got {type(x).__name__}")
|
|
|
|
|
def calculate_output_image_size(input_image_size, stride):
    """Compute the spatial output size of a "same"-padded conv with ``stride``.

    Args:
        input_image_size: int or (height, width) sequence; None passes through.
        stride: int, or per-axis sequence (height stride first).

    Returns:
        ``[out_height, out_width]``, or None when ``input_image_size`` is None.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        # Use each axis' own stride; the original used stride[0] for both,
        # which is wrong for anisotropic strides such as (2, 1).
        stride_h = stride[0]
        stride_w = stride[1] if len(stride) > 1 else stride[0]
    return [int(math.ceil(image_height / stride_h)),
            int(math.ceil(image_width / stride_w))]
|
|
|
|
|
|
|
|
|
|
|
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding for a fixed image size.

    The required zero padding is computed once at construction time from
    ``image_size``, the kernel size, stride, and dilation, then applied in
    ``forward`` before the ordinary convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # nn.Conv2d may hold stride as a 1-tuple; normalize to (sh, sw).
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Validate explicitly; the original used `assert`, which is stripped
        # under `python -O`.
        if image_size is None:
            raise ValueError("image_size must be provided for static same padding")
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        # Target output size for "same" padding: ceil(input / stride).
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        # Total padding needed so the (dilated) kernel covers the input exactly.
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # Asymmetric padding (extra pixel on bottom/right), as TensorFlow does.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            # Use the built-in no-op module instead of the hand-rolled Identity.
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
|
|
|
|
class Identity(nn.Module):
    """No-op module that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input
|
|
|
|
|
|
|
|
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck block (MBConv) with squeeze-and-excitation.

    Structure: 1x1 expansion conv (when ``expand_ratio != 1``) -> depthwise
    conv -> squeeze-and-excitation -> 1x1 projection conv, with a skip
    connection (plus optional drop-connect) when ``stride == 1`` and the
    channel count is unchanged.

    Args:
        ksize: depthwise kernel size (odd values preserve spatial size at stride 1).
        input_filters: number of input channels.
        output_filters: number of output channels.
        expand_ratio: channel expansion factor of the inverted bottleneck.
        stride: depthwise conv stride.
        image_size: kept for interface compatibility; not used in this block.
        drop_connect_rate: probability for drop-connect on the skip path.
    """

    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1,
                 stride=1, image_size=224, drop_connect_rate=0.):
        super().__init__()
        self._bn_mom = 0.1       # BatchNorm momentum
        self._bn_eps = 0.01      # BatchNorm epsilon
        self._se_ratio = 0.25    # squeeze-and-excitation reduction ratio
        self._input_filters = input_filters
        self._output_filters = output_filters
        self._expand_ratio = expand_ratio
        self._kernel_size = ksize
        self._stride = stride
        self._drop_connect_rate = drop_connect_rate

        inp = self._input_filters
        oup = self._input_filters * self._expand_ratio

        # Expansion phase (pointwise conv), skipped when there is no expansion.
        if self._expand_ratio != 1:
            self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup,
                                          kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom,
                                       eps=self._bn_eps)

        # Depthwise convolution phase.
        k = self._kernel_size
        s = self._stride
        # padding=k//2 keeps "same" spatial size for any odd kernel; the
        # original hard-coded padding=1, which is only correct for k == 3.
        self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup,
                                         groups=oup, kernel_size=k, stride=s,
                                         padding=k // 2, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom,
                                   eps=self._bn_eps)

        # Squeeze-and-excitation phase (reduction based on input filters).
        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
        self._se_reduce = nn.Conv2d(in_channels=oup,
                                    out_channels=num_squeezed_channels,
                                    kernel_size=1)
        self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels,
                                    out_channels=oup, kernel_size=1)

        # Output (projection) phase.
        final_oup = self._output_filters
        self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup,
                                       kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom,
                                   eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs):
        """Run the block.

        :param inputs: tensor of shape (batch, input_filters, H, W)
        :return: tensor of shape (batch, output_filters, H', W')
        """
        # Expansion (optional).
        x = inputs
        if self._expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))

        # Depthwise convolution.
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze-and-excitation: global pool -> reduce -> swish -> expand -> gate.
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = torch.sigmoid(x_squeezed) * x

        # Projection.
        x = self._bn2(self._project_conv(x))

        # Skip connection with optional drop-connect, only when shapes match.
        input_filters, output_filters = self._input_filters, self._output_filters
        if self._stride == 1 and input_filters == output_filters:
            if self._drop_connect_rate != 0:
                x = drop_connect(x, p=self._drop_connect_rate,
                                 training=self.training)
            x = x + inputs
        return x
|
|
if __name__ == '__main__':
    # Quick smoke test: run one MBConv block on a random image batch.
    sample = torch.randn(1, 3, 112, 112)
    block = MBConvBlock(ksize=3, input_filters=3, output_filters=3,
                        expand_ratio=4, stride=1)
    print(block)
    output = block(sample)
    print(output.shape)