Spaces:
Build error
Build error
| import math | |
| from functools import partial | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class SwishImplementation(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for backward; ``sigmoid`` is recomputed
    there instead of storing intermediate results from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` replaces the removed/deprecated `saved_variables`.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
    """Module wrapper that applies ``SwishImplementation`` to its input."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        activated = swish_fn(x)
        return activated
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch.

    During training each sample is kept with probability ``1 - p`` and scaled
    by ``1 / (1 - p)``; dropped samples become all zeros. One random draw is
    made per sample and broadcast over every non-batch dimension, so inputs of
    any rank are supported.

    Args:
        inputs: tensor whose first dimension is the batch dimension.
        p: drop probability, must be in ``[0, 1)`` when ``training`` is True.
        training: when False, ``inputs`` is returned unchanged.

    Returns:
        Tensor with the same shape as ``inputs``.

    Raises:
        ValueError: if ``training`` is True and ``p`` is outside ``[0, 1)``.
    """
    if not training:
        return inputs
    if not 0 <= p < 1:
        # p == 1 would divide by zero below and produce NaNs.
        raise ValueError(f"drop_connect probability must be in [0, 1), got {p}")
    keep_prob = 1 - p
    # (batch, 1, 1, ...) so the mask broadcasts per sample for any input rank.
    mask_shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor
def get_same_padding_conv2d(image_size=None):
    """Return a factory for ``Conv2dStaticSamePadding`` with ``image_size`` pre-bound.

    The factory is called with the remaining ``Conv2dStaticSamePadding``
    arguments (channels, kernel size, conv keyword arguments).
    """
    conv_factory = partial(Conv2dStaticSamePadding, image_size=image_size)
    return conv_factory
def get_width_and_height_from_size(x):
    """Expand an int, or pass through a 2-element list/tuple, as a size pair.

    Args:
        x: an ``int`` (used for both dimensions) or a ``list``/``tuple`` of
            exactly two sizes, which is returned unchanged.

    Returns:
        ``(x, x)`` for an int, otherwise ``x`` itself.

    Raises:
        TypeError: if ``x`` is neither an int nor a list/tuple.
        ValueError: if ``x`` is a list/tuple whose length is not 2.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        if len(x) != 2:
            raise ValueError(f"expected a pair of sizes, got {len(x)} values: {x!r}")
        return x
    raise TypeError(f"size must be an int or a list/tuple, got {type(x).__name__}")
def calculate_output_image_size(input_image_size, stride):
    """Compute the spatial output size of a "same"-padded Conv2d with a stride.

    Each spatial dimension becomes ``ceil(size / stride)`` along that axis.

    Args:
        input_image_size: ``None``, an int (square image) or a
            ``(height, width)`` list/tuple.
        stride: an int, or a sequence of ``(stride_h, stride_w)``; a
            single-element sequence applies its value to both axes.

    Returns:
        ``[height, width]`` after the convolution, or ``None`` when
        ``input_image_size`` is ``None``.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        # Use the per-axis strides instead of applying stride[0] to both axes.
        stride_h = stride[0]
        stride_w = stride[1] if len(stride) > 1 else stride[0]
    image_height = int(math.ceil(image_height / stride_h))
    image_width = int(math.ceil(image_width / stride_w))
    return [image_height, image_width]
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-like "same" padding for a fixed image size.

    The padding that makes the output ``ceil(input / stride)`` in each spatial
    dimension is computed once in ``__init__`` from ``image_size``. It is
    applied with a zero-padding layer before every convolution. When the total
    padding is odd, the extra row/column goes on the bottom/right.

    Args:
        in_channels, out_channels, kernel_size: as for ``torch.nn.Conv2d``.
        image_size: an int (square input) or ``(height, width)``; required.
        **kwargs: forwarded to ``torch.nn.Conv2d`` (e.g. ``stride``, ``dilation``).

    Raises:
        ValueError: if ``image_size`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Normalise to a (height, width) stride pair.
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if image_size is None:
            raise ValueError("Conv2dStaticSamePadding requires image_size")
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        dh, dw = self.dilation
        # Target "same" output size, then the total padding needed to reach it.
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * dh + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * dw + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d takes (left, right, top, bottom).
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Identity(nn.Module):
    """Parameter-free module whose ``forward`` returns its argument as-is."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Pass-through: the very same object is handed back.
        return input
# MBConvBlock
class MBConvBlock(nn.Module):
    """Inverted-bottleneck block with squeeze-and-excitation.

    The block runs an optional 1x1 expansion conv, a depthwise ``ksize`` conv,
    squeeze-and-excitation gating and a 1x1 projection. A residual connection
    (with optional drop connect) is added when ``stride == 1`` and the input
    and output channel counts match.

    Example: ``MBConvBlock(ksize=3, input_filters=32, output_filters=16,
    stride=1)`` builds a 3x3 block mapping 32 channels to 16 with stride 1.

    Args:
        ksize: depthwise kernel size (int or ``(h, w)`` pair).
        input_filters: number of input channels.
        output_filters: number of output channels.
        expand_ratio: channel multiplier of the 1x1 expansion conv; 1 skips it.
        stride: stride of the depthwise conv.
        image_size: accepted for API compatibility; not used by this block.
        drop_connect_rate: drop-connect probability on the residual branch
            (0 disables it).
    """

    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224, drop_connect_rate=0.):
        super().__init__()
        self._bn_mom = 0.1
        self._bn_eps = 0.01
        self._se_ratio = 0.25
        self._input_filters = input_filters
        self._output_filters = output_filters
        self._expand_ratio = expand_ratio
        self._kernel_size = ksize
        self._stride = stride
        self._drop_connect_rate = drop_connect_rate

        inp = self._input_filters
        oup = self._input_filters * self._expand_ratio
        if self._expand_ratio != 1:
            self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution (groups == channels).
        k = self._kernel_size
        s = self._stride
        # Half-kernel padding keeps the spatial size at stride 1 for any odd
        # kernel; a fixed padding of 1 was only correct for 3x3 and broke the
        # residual add for other sizes. For k == 3 this is still 1.
        pad = k // 2 if isinstance(k, int) else tuple(d // 2 for d in k)
        self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup, groups=oup,
                                         kernel_size=k, stride=s, padding=pad, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze and Excitation layer; squeezed width is based on the
        # *input* channel count, at least 1.
        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
        self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
        self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase: 1x1 projection to the output channel count.
        final_oup = self._output_filters
        self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs):
        """
        :param inputs: input tensor
        :return: output of block
        """
        # Expansion and depthwise convolution.
        x = inputs
        if self._expand_ratio != 1:
            expand = self._expand_conv(inputs)
            bn0 = self._bn0(expand)
            x = self._swish(bn0)
        depthwise = self._depthwise_conv(x)
        bn1 = self._bn1(depthwise)
        x = self._swish(bn1)

        # Squeeze and excitation: global pool -> reduce -> expand -> sigmoid gate.
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect, only when shapes are compatible.
        input_filters, output_filters = self._input_filters, self._output_filters
        if self._stride == 1 and input_filters == output_filters:
            if self._drop_connect_rate != 0:
                x = drop_connect(x, p=self._drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x
if __name__ == '__main__':
    # Smoke test: one block with 4x expansion on a random 3x112x112 input.
    # The tensor is named `x` to avoid shadowing the builtin `input`.
    x = torch.randn(1, 3, 112, 112)
    mbconv = MBConvBlock(ksize=3, input_filters=3, output_filters=3, expand_ratio=4, stride=1)
    print(mbconv)
    out = mbconv(x)
    print(out.shape)