| from torch import nn
|
| from torch.autograd import Variable
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| import torchvision.models as models
|
|
|
def add_gaussian_noise(ins, mean=0, stddev=0.1):
    '''
    Return `ins` with element-wise Gaussian noise added.

    Args:
        ins: tensor to perturb. The noise buffer is allocated from
            `ins.data` with `ins.size()`, so it has the same shape as `ins`.
        mean: mean of the normal distribution the noise is drawn from.
        stddev: standard deviation of that distribution.

    Returns:
        A new tensor equal to `ins + noise`; `ins` is not modified in place.
    '''
    # Allocate an uninitialised buffer from `ins.data`, then fill it in place
    # with samples from N(mean, stddev).
    perturbation = ins.data.new(ins.size())
    perturbation.normal_(mean, stddev)
    return ins + perturbation
|
|
|
class FlattenLayer(nn.Module):
    '''
    FlattenLayer: collapse every dimension after the first into a single one
    '''

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keep dim 0 as-is and let `view` infer the merged size of the rest.
        leading = x.size(0)
        return x.view(leading, -1)
|
|
|
|
|
class UnflattenLayer(nn.Module):
    '''
    UnflattenLayer: view the input as square maps of side `width`

    The first dimension is kept, the last two become `width` x `width`, and
    the remaining dimension is inferred by `view`.
    '''

    def __init__(self, width):
        super().__init__()
        self.width = width

    def forward(self, x):
        side = self.width
        leading = x.size(0)
        return x.view(leading, -1, side, side)
|
|
|
class VAE_Encoder(nn.Module):
    '''
    VAE_Encoder: Encode image into mu and logvar

    A torchvision ResNet-18 with its last child module removed, followed by
    a flatten step and two linear heads that map the 512 resulting features
    to the mean (`mu`) and log-variance (`logvar`) of the latent Gaussian.
    '''

    def __init__(self, latent_dim=256, pretrained=True):
        '''
        Args:
            latent_dim: output size of both the `mu` and `logvar` heads.
            pretrained: forwarded to `models.resnet18`. Defaults to True
                (the previous hard-coded value); pass False to skip the
                pretrained weights, e.g. when a checkpoint is loaded into
                this module right after construction.
        '''
        super(VAE_Encoder, self).__init__()
        backbone = models.resnet18(pretrained=pretrained)
        backbone.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Drop the last child of the ResNet and flatten what is left.
        # The attribute name `resnet` and the child order are kept as they
        # determine the state_dict keys (`resnet.<idx>.*`) of checkpoints.
        self.resnet = nn.Sequential(
            *list(backbone.children())[:-1],
            FlattenLayer()
        )

        self.l_mu = nn.Linear(512, latent_dim)
        self.l_var = nn.Linear(512, latent_dim)

    def encode(self, x):
        '''
        Run the backbone on `x` and return the tuple `(mu, logvar)`.
        '''
        hidden = self.resnet(x)
        mu = self.l_mu(hidden)
        logvar = self.l_var(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        '''
        Draw a latent sample using the reparameterization trick.

        In training mode returns `mu + eps * std` with
        `std = exp(0.5 * logvar)` and `eps` sampled from a standard normal
        of the same shape; in eval mode returns `mu` unchanged.
        '''
        if not self.training:
            return mu

        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        '''
        Encode `x` and return `(z, mu, logvar)`, where `z` is the output of
        `reparameterize(mu, logvar)`.
        '''
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
|
|
|
|
|
class VAE_Decoder(nn.Module):
    '''
    VAE_Decoder: Decode noise to image

    The latent input is viewed as a 1x1 feature map and upsampled by a
    stack of transposed convolutions: the first (kernel 4, stride 1,
    padding 0) yields 4x4, and each of the six following ones (kernel 4,
    stride 2, padding 1) doubles the spatial size, giving a 256x256 output.
    The final `Tanh` bounds the output to [-1, 1].
    '''

    def __init__(self, latent_dim, output_dim=3):
        '''
        Args:
            latent_dim: number of channels of the latent code fed to the
                first transposed convolution.
            output_dim: number of channels of the decoded image. The default
                of 3 matches the previously hard-coded value.
        '''
        super(VAE_Decoder, self).__init__()
        # NOTE: the layer order fixes the state_dict keys (`convs.<idx>.*`),
        # so it is left exactly as it was to stay checkpoint-compatible.
        self.convs = nn.Sequential(
            UnflattenLayer(width=1),
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            # Previously hard-coded to 3 channels, which silently ignored
            # the `output_dim` argument.
            nn.ConvTranspose2d(32, output_dim, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        '''
        Decode the latent code `z` into an image tensor.
        '''
        return self.convs(z)
|
|
|
class ImageAE(nn.Module):
    '''
    ImageAE: a VAE_Encoder chained into a VAE_Decoder, both built with a
    latent size of 512
    '''

    def __init__(self):
        super().__init__()
        latent_dim = 512
        self.enc = VAE_Encoder(latent_dim)
        self.dec = VAE_Decoder(latent_dim)

    def forward(self, x):
        '''
        Encode `x`, keep only the first value the encoder returns, and
        decode it.
        '''
        latent, *_ = self.enc(x)
        return self.dec(latent)

    def load_ckpt(self, enc_path, dec_path):
        '''
        Load encoder and decoder weights from two checkpoint files.

        Both files are read with `map_location='cpu'`. The encoder
        checkpoint is loaded and applied before the decoder one is read.
        '''
        targets = ((self.enc, enc_path), (self.dec, dec_path))
        for module, ckpt_path in targets:
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
|
|
|
|
|
def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'):
    '''
    Build an ImageAE, load its encoder/decoder checkpoints and return it in
    eval mode.

    Args:
        enc_path: path of the encoder checkpoint.
        dec_path: path of the decoder checkpoint.

    Returns:
        The ImageAE instance after `eval()` has been called on it.
    '''
    model = ImageAE()
    model.load_ckpt(enc_path, dec_path)
    print('load image auto-encoder')
    # ImageAE inherits nn.Module.eval(), which returns the module itself.
    return model.eval()
|
|
|
|
|
def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'):
    '''
    Build a `UnetGenerator(3, 3, 8)`, load its weights from `path` and
    return it after calling `eval()`.

    Args:
        path: path of the checkpoint, read with `map_location='cpu'`.
    '''
    # NOTE(review): UnetGenerator is neither defined nor imported in the
    # visible part of this file - confirm it is in scope where this runs.
    net = UnetGenerator(3, 3, 8)
    weights = torch.load(path, map_location='cpu')
    net.load_state_dict(weights)
    print('load Unet')
    net.eval()
    return net
|
|
|
if __name__ == "__main__":
    # Manual smoke check: build the auto-encoder from the default checkpoint
    # paths and print its module structure.
    ae = get_pretraiend_ae()
    print(ae)
|
|
|