| import math
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| import torch.utils.model_zoo as model_zoo
|
| from torch import nn
|
| from torch.nn import Parameter
|
| import pdb
|
| import numpy as np
|
|
|
|
|
| class ResidualBlock(nn.Module):
|
| def __init__(self, inchannel, outchannel, stride=1):
|
| super(ResidualBlock, self).__init__()
|
| self.left = nn.Sequential(
|
| nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
|
| nn.BatchNorm2d(outchannel),
|
| nn.ReLU(inplace=True),
|
| nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
|
| nn.BatchNorm2d(outchannel)
|
| )
|
| self.shortcut = nn.Sequential()
|
| if stride != 1 or inchannel != outchannel:
|
| self.shortcut = nn.Sequential(
|
| nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
|
| nn.BatchNorm2d(outchannel)
|
| )
|
|
|
| def forward(self, x):
|
| out = self.left(x)
|
| out += self.shortcut(x)
|
| out = F.relu(out)
|
| return out
|
|
|
|
|
| class ResNet(nn.Module):
|
| def __init__(self, ResidualBlock = ResidualBlock):
|
| super(ResNet, self).__init__()
|
| self.inchannel = 64
|
| self.conv1 = nn.Sequential(
|
| nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
| nn.BatchNorm2d(64),
|
| nn.ReLU(),
|
| )
|
| self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
|
| self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
|
| self.layer3 = self.make_layer(ResidualBlock, 196, 2, stride=2)
|
| self.layer4 = self.make_layer(ResidualBlock, 256, 2, stride=2)
|
| self.lastconv1 = nn.Sequential(
|
| nn.Conv2d(128 + 196 + 256, 128, kernel_size=3, stride=1, padding=1, bias=False),
|
| nn.BatchNorm2d(128),
|
| nn.ReLU(),
|
| )
|
| self.lastconv2 = nn.Sequential(
|
| nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
| nn.BatchNorm2d(64),
|
| nn.ReLU(),
|
| )
|
| self.lastconv3 = nn.Sequential(
|
| nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
|
| nn.ReLU(),
|
| )
|
| self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear')
|
|
|
| self.mu_head = nn.Sequential(
|
| nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
|
|
|
| )
|
| self.logvar_head = nn.Sequential(
|
| nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
|
|
|
| )
|
|
|
| def make_layer(self, block, channels, num_blocks, stride):
|
| strides = [stride] + [1] * (num_blocks - 1)
|
| layers = []
|
| for stride in strides:
|
| layers.append(block(self.inchannel, channels, stride))
|
| self.inchannel = channels
|
| return nn.Sequential(*layers)
|
|
|
| def _reparameterize(self, mu, logvar):
|
| std = torch.exp(logvar).sqrt()
|
| epsilon = torch.randn_like(std)
|
| return mu + epsilon * std
|
|
|
| def forward(self, x):
|
| x_input = x
|
| x = self.conv1(x)
|
| x_block1 = self.layer1(x)
|
| x_block2 = self.layer2(x_block1)
|
| x_block2_32 = self.downsample32x32(x_block2)
|
| x_block3 = self.layer3(x_block2)
|
| x_block3_32 = self.downsample32x32(x_block3)
|
| x_block4 = self.layer4(x_block3)
|
| x_block4_32 = self.downsample32x32(x_block4)
|
|
|
| x_concat = torch.cat((x_block2_32, x_block3_32, x_block4_32), dim=1)
|
|
|
| x = self.lastconv1(x_concat)
|
| x = self.lastconv2(x)
|
| mu = self.mu_head(x)
|
| mu = mu.squeeze(1)
|
| logvar = self.logvar_head(x)
|
| logvar = logvar.squeeze(1)
|
| embedding = self._reparameterize(mu, logvar)
|
|
|
| return mu, logvar, embedding, x_concat, x_block2, x_block3, x_block4, x_input
|
|
|
|
|
| def ResNet18_u():
|
|
|
| return ResNet(ResidualBlock)
|
|
|
|
|