Spaces:
Runtime error
| from torch import nn | |
class Conv(nn.Module):
    """2-D convolution followed by optional BatchNorm2d and optional ReLU.

    Padding is ``(k - 1) // 2`` per spatial dimension. For odd kernel sizes at
    stride 1 this keeps height and width unchanged. Even kernel sizes shrink
    the output by one pixel per dimension.

    Args:
        inp_dim: Expected number of input channels; checked in ``forward``.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size, either an int or a
            ``(kh, kw)`` tuple/list (both accepted by ``nn.Conv2d``).
        stride: Convolution stride, passed straight to ``nn.Conv2d``.
        bn: If truthy, apply ``nn.BatchNorm2d(out_dim)`` after the conv.
        relu: If truthy, apply ``nn.ReLU()`` last.
        bias: Whether the convolution has a learnable bias. Defaults to
            ``False``, which preserves the previous hard-coded behavior.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=False):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        # Per-dimension padding so tuple kernels work too; scalar kernels
        # (including non-builtin integer types) keep the original formula.
        if isinstance(kernel_size, (tuple, list)):
            padding = tuple((k - 1) // 2 for k in kernel_size)
        else:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=padding, bias=bias)
        # Optional stages are stored as None when disabled; forward() skips them.
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        """Apply conv -> (bn) -> (relu) to ``x``.

        Raises:
            AssertionError: if ``x.size()[1]`` differs from ``inp_dim``.
        """
        # NOTE: assert is stripped under ``python -O``; kept as-is so the
        # exception type callers may rely on is unchanged.
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class Deconv(nn.Module):
    """Transposed 2-D convolution with optional BatchNorm2d and ReLU.

    The transposed convolution uses no padding and no bias. Per PyTorch's
    ConvTranspose2d formula, each spatial dimension therefore grows to
    ``(size - 1) * stride + kernel_size``.

    Args:
        inp_dim: Expected number of input channels; checked in ``forward``.
        out_dim: Number of output channels.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
        bn: If truthy, apply ``nn.BatchNorm2d(out_dim)`` after the deconv.
        relu: If truthy, apply ``nn.ReLU()`` last.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super().__init__()
        self.inp_dim = inp_dim
        self.deconv = nn.ConvTranspose2d(
            inp_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )
        # Disabled stages are kept as None so forward() can skip them.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply deconv -> (bn) -> (relu); asserts the channel count matches."""
        channels = x.size()[1]
        assert channels == self.inp_dim, "{} {}".format(channels, self.inp_dim)
        result = self.deconv(x)
        # Post-ops run in this fixed order: normalization first, then activation.
        for stage in (self.bn, self.relu):
            if stage is not None:
                result = stage(result)
        return result
class Residual(nn.Module):
    """Residual block with three BN -> ReLU -> Conv stages.

    Main path: ``bn1 -> relu -> conv1 (1x1, inp_dim -> out_dim // 2)``, then
    ``bn2 -> relu -> conv2 (kernel x kernel, out_dim // 2 -> out_dim // 2)``,
    then ``bn3 -> relu -> conv3 (1x1, out_dim // 2 -> out_dim)``. All convs are
    built with ``relu=False``.

    Shortcut: ``x`` itself when ``inp_dim == out_dim``. Otherwise it is a 1x1
    ``skip_layer`` projection to ``out_dim`` channels.

    Args:
        inp_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel: Kernel size of the middle convolution.
    """

    def __init__(self, inp_dim, out_dim, kernel=3):
        super(Residual, self).__init__()
        # Width of the inner stages; integer division avoids the float round-trip.
        mid_dim = out_dim // 2
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = Conv(mid_dim, mid_dim, kernel, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
        # Constructed unconditionally, so its parameters are registered (and
        # appear in state_dict) even when need_skip is False and it is unused.
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``."""
        residual = self.skip_layer(x) if self.need_skip else x
        out = x
        for bn, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            out = conv(self.relu(bn(out)))
        # In-place add into conv3's fresh output tensor (x is not modified).
        out += residual
        return out