# Source path: show/SHOW/modules/DECA/decalib/models/decoders.py
# Commit 3bbb319 by camenduru: "thanks to show ❤"
# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at deca@tue.mpg.de
# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Decode a latent vector into an image-like map.

    A linear layer projects the latent code to a 128-channel 8x8 feature map.
    Five stages of ``Upsample(x2)`` + 3x3 conv + BatchNorm + LeakyReLU then
    grow it to 256x256 (8 -> 16 -> 32 -> 64 -> 128 -> 256). A final 3x3 conv
    maps to ``out_channels``, ``Tanh`` bounds the result to [-1, 1], and the
    result is multiplied by ``out_scale``.

    Args:
        latent_dim: Size of the last dimension of the input ``noise``.
        out_channels: Number of channels in the output map.
        out_scale: Multiplier applied after ``Tanh``; outputs therefore lie in
            ``[-|out_scale|, |out_scale|]``.
        sample_mode: Interpolation mode forwarded to every ``nn.Upsample``.

    The submodule layout (``l1`` plus the 23 entries of ``conv_blocks``)
    determines the ``state_dict`` keys; keep the layer order stable so saved
    checkpoints continue to load.
    """

    def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode='bilinear'):
        super().__init__()
        self.out_scale = out_scale
        self.init_size = 32 // 4  # 8x8 start; five x2 upsamplings reach 256x256
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # (in_channels, out_channels) of each x2 upsampling stage,
        # resolutions 16, 32, 64, 128, 256 respectively.
        stage_channels = ((128, 128), (128, 64), (64, 64), (64, 32), (32, 16))

        layers = [nn.BatchNorm2d(128)]
        for c_in, c_out in stage_channels:
            layers += [
                nn.Upsample(scale_factor=2, mode=sample_mode),
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                # NOTE: BatchNorm2d's second positional parameter is ``eps``,
                # not ``momentum``, so ``BatchNorm2d(c, 0.8)`` means eps=0.8.
                # Spelled out explicitly and intentionally NOT changed:
                # altering eps would change the outputs of any checkpoint
                # trained with this module.
                nn.BatchNorm2d(c_out, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Conv2d(stage_channels[-1][1], out_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        # Do not reorder: list position fixes the ``conv_blocks.<i>`` keys
        # (0 = first BatchNorm, 2/6/10/14/18 = stage convs, 21 = output conv)
        # and the order in which conv weights are randomly initialised.
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, noise):
        """Decode ``noise`` of shape (B, latent_dim) into (B, out_channels, 256, 256).

        The ``view`` below uses ``out.shape[0]`` as the batch size, so
        ``noise`` must be 2-D.
        """
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img * self.out_scale