# model_utils.py — U-Net building blocks (unetConv2, unetUp).
import torch
import torch.nn as nn
import torch.nn.functional as F
class unetConv2(nn.Module):
    """Two stacked 3x3 conv + ReLU stages, optionally with batch norm.

    Each stage is Conv2d(kernel 3, stride 1, padding 1), so spatial size is
    preserved; channels go from ``in_size`` to ``out_size`` in the first
    stage and stay at ``out_size`` in the second.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()

        def _stage(c_in, c_out):
            # Build one conv stage; BatchNorm is inserted between the conv
            # and the ReLU only when requested, matching the layer order
            # (and hence the state-dict keys) of the two-branch original.
            layers = [nn.Conv2d(c_in, c_out, 3, 1, 1)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        self.conv1 = _stage(in_size, out_size)
        self.conv2 = _stage(out_size, out_size)

    def forward(self, inputs):
        """Run both conv stages in sequence and return the result."""
        return self.conv2(self.conv1(inputs))
class unetUp(nn.Module):
    """U-Net decoder stage: upsample, align the skip tensor, concat, conv.

    Args:
        in_size: channel count fed to the fusing conv (skip channels plus
            upsampled channels after ``self.up``).
        out_size: output channel count of the fusing conv.
        is_deconv: when True use a learned 2x ConvTranspose2d (which also
            maps in_size -> out_size channels); otherwise a fixed bilinear
            2x upsampling that leaves channels unchanged.
        is_batchnorm: forwarded to the fusing unetConv2 block.
    """

    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, is_batchnorm)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """Upsample ``inputs2``, pad/crop ``inputs1`` to match, concat, conv.

        ``inputs1`` is the encoder skip tensor; ``inputs2`` is the coarser
        decoder tensor to be upsampled.
        """
        outputs2 = self.up(inputs2)
        # Align the skip tensor to the upsampled one. Height and width are
        # handled independently, and an odd size difference is split
        # asymmetrically (left/top gets the smaller half). The original
        # `2 * [offset // 2, offset // 2]` used only the height offset and
        # under-padded by one pixel for odd offsets, making torch.cat raise.
        # Negative offsets produce negative pads, i.e. a crop, as before.
        offset_h = outputs2.size(2) - inputs1.size(2)
        offset_w = outputs2.size(3) - inputs1.size(3)
        padding = [
            offset_w // 2, offset_w - offset_w // 2,  # left, right
            offset_h // 2, offset_h - offset_h // 2,  # top, bottom
        ]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))