import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, reduce, repeat
# NLayerDiscriminatorSN below references a SpectralNorm wrapper that is not defined
# in this file; aliasing PyTorch's built-in spectral_norm is an assumption about the
# intended implementation, but it keeps the module self-contained.
from torch.nn.utils import spectral_norm as SpectralNorm

class SelfAttention(nn.Module):
    """Self-attention layer (SAGAN-style) for 2D feature maps."""

    def __init__(self, input_channel, activation="relu"):
        super(SelfAttention, self).__init__()
        self.channel_in = input_channel
        self.activation = activation

        self.query_conv = nn.Conv2d(input_channel, input_channel // 8, 1)
        self.key_conv = nn.Conv2d(input_channel, input_channel // 8, 1)
        self.value_conv = nn.Conv2d(input_channel, input_channel, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        # Project to queries (B, N, C//8) and keys (B, C//8, N), with N = H * W
        attention_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)
        attention_key = self.key_conv(x).view(m_batchsize, -1, height * width)
        # (B, N, N) attention map, softmax over the last dimension
        energy = torch.bmm(attention_query, attention_key)
        attention = self.softmax(energy)
        attention_value = self.value_conv(x).view(m_batchsize, -1, height * width)

        out = torch.bmm(attention_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        # Residual connection scaled by a learnable gamma (initialised to zero)
        out = self.gamma * out + x

        return out
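
# Usage sketch (illustrative, not part of the original code): the constructor
# argument must match the channel count of the incoming feature map.
#     attn = SelfAttention(128)
#     features = torch.randn(4, 128, 64, 64)   # hypothetical (B, C, H, W) tensor
#     out = attn(features)                      # same shape as the input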


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
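
# Usage sketch (illustrative): resolve a norm layer by name and hand it to one of
# the network constructors defined below.
#     norm_layer = get_norm_layer('instance')
#     netG = ResnetGenerator(3, 3, ngf=64, norm_layer=norm_layer, n_blocks=9)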


def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
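
# Usage sketch (illustrative): `opt` is assumed to be an options namespace with
# lr_policy, epoch_count, niter, niter_decay and lr_decay_iters attributes.
#     scheduler = get_scheduler(optimizer, opt)
#     for epoch in range(opt.niter + opt.niter_decay):
#         ...  # train one epoch
#         scheduler.step()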


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, gain=init_gain)
    return net
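
# Usage sketch (illustrative): construct, initialise and (optionally) move a
# network to the GPU in one call.
#     netG = ResnetGenerator(3, 3)
#     netG = init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=[0])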


class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
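
# Usage sketch (illustrative): least-squares GAN loss against discriminator outputs.
#     criterionGAN = GANLoss(use_lsgan=True)
#     loss_D_real = criterionGAN(netD(real_images), True)
#     loss_D_fake = criterionGAN(netD(fake_images.detach()), False)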


class GradPenalty(nn.Module):
    def __init__(self, use_cuda):
        super(GradPenalty, self).__init__()
        self.use_cuda = use_cuda

    def forward(self, critic, real_data, fake_data):
        alpha = torch.rand_like(real_data)

        assignGPU = lambda x: x.cuda() if self.use_cuda else x
        alpha = assignGPU(alpha)

        # Random points on the lines between real and fake samples
        interpolates = alpha * real_data + (1 - alpha) * fake_data.detach()
        interpolates = assignGPU(interpolates).requires_grad_(True)

        critic_interpolates = critic(interpolates)

        gradients = torch.autograd.grad(
            outputs=critic_interpolates,
            inputs=interpolates,
            grad_outputs=assignGPU(torch.ones_like(critic_interpolates)),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        # WGAN-GP penalty: squared deviation of the gradient norm from 1
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty
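
# Usage sketch (illustrative): WGAN-GP critic loss with a hypothetical penalty weight of 10.
#     grad_penalty = GradPenalty(use_cuda=True)
#     loss_D = -netD(real).mean() + netD(fake.detach()).mean() \
#              + 10.0 * grad_penalty(netD, real, fake)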


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def l2norm(t):
    return F.normalize(t, dim=-1)


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # Layer norm over the channel dimension of a (B, C, H, W) tensor
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', use_attention=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # InstanceNorm has no affine parameters by default, so the convolutions keep their bias
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(
                input_nc, ngf,
                kernel_size=7,
                padding=0,
                bias=use_bias
            ),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        # Two stride-2 downsampling convolutions
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult, ngf * mult * 2, kernel_size=3,
                    stride=2, padding=1, bias=use_bias
                ),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        # Residual blocks at the bottleneck resolution
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(
                    ngf * mult,
                    padding_type=padding_type,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    use_bias=use_bias
                )
            ]

        # Two stride-2 upsampling (transposed) convolutions
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult, int(ngf * mult / 2),
                    kernel_size=3, stride=2,
                    padding=1, output_padding=1,
                    bias=use_bias
                ),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

            # Self-attention after the first upsampling layer; the channel count
            # matches that layer's output (128 for the default ngf=64)
            if use_attention and i == 0:
                model += [SelfAttention(int(ngf * mult / 2), 'relu')]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
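
# Usage sketch (illustrative): a 9-block ResNet generator for 3-channel images.
#     netG = ResnetGenerator(3, 3, ngf=64, norm_layer=get_norm_layer('instance'),
#                            n_blocks=9, use_attention=True)
#     fake = netG(torch.randn(1, 3, 256, 256))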


class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True)
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class UnetGenerator(nn.Module):
    """U-Net generator built from nested UnetSkipConnectionBlock modules.

    `num_downs` is the number of downsamplings; each halves the spatial size,
    so an input of size 2**num_downs reaches a 1x1 bottleneck.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        num_downs, ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False
    ):
        super(UnetGenerator, self).__init__()

        # Build the U-Net recursively from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True
        )
        # Intermediate ngf*8 blocks (num_downs - 5 of them)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
                use_dropout=use_dropout
            )
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8,
            input_nc=None,
            submodule=unet_block,
            norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4,
            input_nc=None,
            submodule=unet_block,
            norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2,
            input_nc=None,
            submodule=unet_block,
            norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            output_nc, ngf,
            input_nc=input_nc,
            submodule=unet_block,
            outermost=True,
            norm_layer=norm_layer
        )

        self.model = unet_block

    def forward(self, input):
        return self.model(input)
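
# Usage sketch (illustrative): each skip-connection block downsamples by 2, so
# num_downs=8 takes a 256x256 input down to a 1x1 bottleneck.
#     netG = UnetGenerator(3, 3, num_downs=8, ngf=64, use_dropout=True)
#     fake = netG(torch.randn(1, 3, 256, 256))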


class UnetSkipConnectionBlock(nn.Module):
    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False
    ):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(
            input_nc, inner_nc, kernel_size=4,
            stride=2, padding=1, bias=use_bias
        )
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc,
                kernel_size=4, stride=2,
                padding=1
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc,
                kernel_size=4, stride=2,
                padding=1, bias=use_bias
            )
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc,
                kernel_size=4, stride=2,
                padding=1, bias=use_bias
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)


class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_attention=False):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev, ndf * nf_mult,
                    kernel_size=kw, stride=2, padding=padw, bias=use_bias
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev, ndf * nf_mult,
                kernel_size=kw, stride=1,
                padding=padw, bias=use_bias
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Self-attention over the last feature map; the channel count matches the
        # preceding layer (512 for the default ndf=64, n_layers=3)
        if use_attention:
            sequence += [SelfAttention(ndf * nf_mult, 'relu')]
        # Final 1-channel prediction map (PatchGAN output)
        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
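
# Usage sketch (illustrative): a PatchGAN-style discriminator producing per-patch scores.
#     netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3,
#                                norm_layer=get_norm_layer('instance'))
#     pred = netD(torch.randn(1, 3, 256, 256))  # (1, 1, H', W') score map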


class NLayerDiscriminatorSN(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, use_attention=False):
        super(NLayerDiscriminatorSN, self).__init__()
        # Spectral normalisation replaces the explicit norm layers here
        use_bias = False

        kw = 4
        padw = 1
        sequence = [
            SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                SpectralNorm(
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw, stride=2,
                        padding=padw, bias=use_bias
                    )
                ),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            SpectralNorm(
                nn.Conv2d(
                    ndf * nf_mult_prev, ndf * nf_mult,
                    kernel_size=kw, stride=1, padding=padw, bias=use_bias
                )
            ),
            nn.LeakyReLU(0.2, True)
        ]
        # Self-attention over the last feature map (512 channels for the default ndf=64, n_layers=3)
        if use_attention:
            sequence += [SelfAttention(ndf * nf_mult, 'relu')]
        sequence += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


class PixelDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        if use_sigmoid:
            self.net.append(nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        return self.net(input)