| import torch | |
| from torch import nn | |
| import numpy as np | |
class DepthwiseSeperableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution: a per-channel conv, then a 1x1 conv.

    Args:
        input_channels: Number of channels in the input tensor.
        output_channels: Number of channels produced by the 1x1 pointwise conv.
        **kwargs: Forwarded unchanged to the depthwise ``nn.Conv2d`` (for
            example ``kernel_size``, ``stride``, ``padding``, ``bias``).
            ``groups`` is fixed here, so passing it raises ``TypeError``.
    """

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        # groups == channel count -> every input channel is convolved with its
        # own filter (channel multiplier 1).
        self.depthwise = nn.Conv2d(
            input_channels,
            input_channels,
            groups=input_channels,
            **kwargs,
        )
        # NOTE: the 1x1 mixing conv always uses nn.Conv2d's default bias; a
        # ``bias=False`` in **kwargs only reaches the depthwise conv above.
        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        """Run ``x`` through the depthwise conv and then the pointwise conv."""
        return self.pointwise(self.depthwise(x))
class Conv2dBlock(nn.Module):
    """Reflection pad -> depthwise-separable conv -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (also the BatchNorm2d width).
        kernel_size: Integer kernel size. The input is reflection-padded by
            ``int((kernel_size - 1) / 2)`` on every side and the conv itself uses
            ``padding=0``, so with an odd kernel and ``stride=1`` the spatial
            size is preserved; an even kernel shrinks it by one.
        stride: Stride passed to ``DepthwiseSeperableConv2d``.
        bias: Passed to ``DepthwiseSeperableConv2d``, which forwards it to its
            depthwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
        super().__init__()
        pad = int((kernel_size - 1) / 2)
        layers = [
            nn.ReflectionPad2d(pad),
            DepthwiseSeperableConv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,  # all padding is done by the reflection layer above
                bias=bias,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the pad/conv/norm/activation stack to ``x``."""
        return self.model(x)
class Concat(nn.Module):
    """Run every submodule on the same input and concatenate the results.

    Outputs whose sizes differ along dims 2 and up are center-cropped to the
    smallest size in each of those dims before ``torch.cat``. Dims 0 and 1
    (presumably batch and channel — confirm with callers) are never cropped.
    If ``dim`` itself is >= 2, it is cropped too, as before.

    Args:
        dim: Dimension along which the submodule outputs are concatenated.
        *args: Submodules, registered under the names ``"0"``, ``"1"``, ...
    """

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        """Concatenate all submodule outputs for ``input`` along ``self.dim``.

        Raises:
            ValueError: if the module has no submodules.
        """
        inputs = [module(input) for module in self._modules.values()]
        if not inputs:
            raise ValueError("Concat has no submodules to concatenate")

        # Crop every trailing dim (not just 2 and 3) so 3-D and 5-D inputs work;
        # for 4-D inputs this is exactly the previous behaviour.
        spatial_dims = range(2, inputs[0].dim())
        target = [min(t.size(d) for t in inputs) for d in spatial_dims]

        # Fast path: sizes already agree, nothing to crop.
        if all(list(t.shape[2:]) == target for t in inputs):
            return torch.cat(inputs, dim=self.dim)

        cropped = []
        for t in inputs:
            index = [slice(None), slice(None)]
            for d, size in zip(spatial_dims, target):
                # Center crop; an odd surplus leaves the extra element at the end.
                start = (t.size(d) - size) // 2
                index.append(slice(start, start + size))
            cropped.append(t[tuple(index)])
        return torch.cat(cropped, dim=self.dim)

    def __len__(self):
        """Return the number of registered submodules."""
        return len(self._modules)