| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class unetConv2(nn.Module): | |
| def __init__(self, in_size, out_size, is_batchnorm): | |
| super(unetConv2, self).__init__() | |
| if is_batchnorm: | |
| self.conv1 = nn.Sequential( | |
| nn.Conv2d(in_size, out_size, 3, 1, 1), | |
| nn.BatchNorm2d(out_size), | |
| nn.ReLU(), | |
| ) | |
| self.conv2 = nn.Sequential( | |
| nn.Conv2d(out_size, out_size, 3, 1, 1), | |
| nn.BatchNorm2d(out_size), | |
| nn.ReLU(), | |
| ) | |
| else: | |
| self.conv1 = nn.Sequential( | |
| nn.Conv2d(in_size, out_size, 3, 1, 1), nn.ReLU() | |
| ) | |
| self.conv2 = nn.Sequential( | |
| nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU() | |
| ) | |
| def forward(self, inputs): | |
| outputs = self.conv1(inputs) | |
| outputs = self.conv2(outputs) | |
| return outputs | |
| class unetUp(nn.Module): | |
| def __init__(self, in_size, out_size, is_deconv, is_batchnorm): | |
| super(unetUp, self).__init__() | |
| self.conv = unetConv2(in_size, out_size, is_batchnorm) | |
| if is_deconv: | |
| self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) | |
| else: | |
| self.up = nn.UpsamplingBilinear2d(scale_factor=2) | |
| def forward(self, inputs1, inputs2): | |
| outputs2 = self.up(inputs2) | |
| offset = outputs2.size()[2] - inputs1.size()[2] | |
| padding = 2 * [offset // 2, offset // 2] | |
| outputs1 = F.pad(inputs1, padding) | |
| return self.conv(torch.cat([outputs1, outputs2], 1)) |