|
|
from networks import ResnetBlock |
|
|
import functools |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GlobalGenerator(nn.Module):
    """Convolutional generator: stem -> downsampling -> residual blocks -> upsampling -> head.

    The whole network is stored as one ``nn.Sequential`` in ``self.model``:

    1. Stem: ``ReflectionPad2d(3)`` + 7x7 conv from ``input_nc`` to ``ngf``
       channels, then norm and activation.
    2. ``n_downsampling`` stages of 3x3 stride-2 conv; each stage doubles the
       channel count, followed by norm and activation.
    3. ``n_blocks`` ``ResnetBlock`` modules at ``ngf * 2**n_downsampling``
       channels.
    4. ``n_downsampling`` stages of 3x3 stride-2 transposed conv; each stage
       halves the channel count, followed by norm and activation.
    5. Head: ``ReflectionPad2d(3)`` + 7x7 conv from ``ngf`` to ``output_nc``
       channels, then ``Tanh`` (no norm).

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels of the output tensor.
        ngf: Number of channels produced by the stem convolution.
        n_downsampling: Number of stride-2 stages on each side of the
            residual blocks.
        n_blocks: Number of ``ResnetBlock`` modules; must be >= 0.
        norm_layer: Callable taking a channel count and returning a
            normalisation module.
        padding_type: Forwarded unchanged to ``ResnetBlock``. The stem and
            head always use reflection padding regardless of this value.

    Raises:
        AssertionError: If ``n_blocks`` is negative.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_downsampling=4, n_blocks=9,
                 norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                 padding_type='reflect'):
        assert n_blocks >= 0, "n_blocks must be non-negative"
        super(GlobalGenerator, self).__init__()

        # A single in-place ReLU instance is reused at every activation site;
        # it holds no parameters, so sharing it is safe.
        activation = nn.ReLU(True)

        # Stem: the 3-pixel reflection pad compensates the unpadded 7x7 conv.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            activation,
        ]

        # Downsampling: channels go ngf -> 2*ngf -> ... -> ngf * 2**n_downsampling.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** i
            out_ch = in_ch * 2
            model += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                norm_layer(out_ch),
                activation,
            ]

        # Residual blocks at the widest channel count.
        # NOTE(review): ResnetBlock comes from the project's `networks` module;
        # it is presumably shape-preserving -- confirm against its definition.
        bottleneck_ch = ngf * 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [
                ResnetBlock(bottleneck_ch, padding_type=padding_type,
                            activation=activation, norm_layer=norm_layer)
            ]

        # Upsampling mirrors the downsampling path. Integer floor division
        # keeps the channel count an int without a float round-trip.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** (n_downsampling - i)
            out_ch = in_ch // 2
            model += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(out_ch),
                activation,
            ]

        # Head: project back to output_nc channels; Tanh bounds values to [-1, 1].
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Run ``input`` through the sequential model and return the result."""
        return self.model(input)
|
|
|