| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from torchvision import models |
| | import torch |
| | from torch.nn import init |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import functools |
| | from torch.autograd import grad |
| |
|
| |
|
def gradient(inputs, outputs):
    """Differentiate `outputs` with respect to `inputs`.

    Parameters:
        inputs (tensor)  -- tensor the derivative is taken with respect to
        outputs (tensor) -- tensor computed from `inputs`

    An all-ones tensor shaped like `outputs` is used as `grad_outputs`, which
    amounts to differentiating the sum of all output elements. The graph of
    the derivative is built and the forward graph retained, so the result can
    itself be back-propagated through.

    Returns the gradient for `inputs`; because `allow_unused=True` is passed,
    this is None when `outputs` does not depend on `inputs`.
    """
    # ones_like already inherits device from `outputs` and never requires grad.
    seed = torch.ones_like(outputs)
    all_grads = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=seed,
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )
    return all_grads[0]
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
def conv3x3(in_planes,
            out_planes,
            kernel=3,
            strd=1,
            dilation=1,
            padding=1,
            bias=False):
    """Build a 2D convolution that defaults to a padded 3x3 kernel.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        kernel (int)     -- kernel size
        strd (int)       -- stride
        dilation (int)   -- dilation
        padding (int)    -- zero padding applied on each side
        bias (bool)      -- whether the layer has a learnable bias

    Returns a freshly constructed nn.Conv2d.
    """
    conv_kwargs = dict(kernel_size=kernel,
                       stride=strd,
                       padding=padding,
                       dilation=dilation,
                       bias=bias)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
| |
|
| |
|
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 2D convolution.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        stride (int)     -- stride of the convolution

    Returns a freshly constructed nn.Conv2d.
    """
    pointwise = nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False)
    return pointwise
| |
|
| |
|
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Modules are selected by class name: anything whose class name contains
    'Conv' or 'Linear' gets the chosen scheme (bias set to 0), and anything
    whose class name contains 'BatchNorm2d' gets weight ~ N(1, init_gain) and
    bias 0. Modules whose weight/bias is None (e.g. BatchNorm2d(affine=False))
    are skipped instead of raising.

    Raises NotImplementedError for an unknown `init_type` (only when the
    network actually contains a Conv/Linear module with a weight).
    """
    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        if weight is not None and ('Conv' in classname
                                   or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' %
                    init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            # With affine=False both weight and bias are None, so guard each one.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    # apply() visits every sub-module recursively, then `net` itself.
    net.apply(init_func)
| |
|
| |
|
def init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[]):
    """Optionally wrap a network for multi-GPU use, then initialize its weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    When `gpu_ids` is non-empty, CUDA availability is asserted and the network
    is wrapped in torch.nn.DataParallel. Note that the ids themselves are only
    used for this emptiness check; they are not forwarded to DataParallel.

    Return the initialized network (the DataParallel wrapper if one was made).
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert torch.cuda.is_available()
        net = torch.nn.DataParallel(net)

    init_weights(net, init_type, init_gain=init_gain)
    return net
| |
|
| |
|
def imageSpaceRotation(xy, rot):
    '''
    Weight each coordinate row of `xy` by the sine of its angle and sum the rows.

    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles

    returns:
        (B, N) tensor: sin(rot[:, 0]) * xy[:, 0] + sin(rot[:, 1]) * xy[:, 1]

    rotation center will be always image center (other rotation center can be represented by additional z translation)
    '''
    # (B, 2) -> (B, 2, 1) -> broadcast to xy's shape
    sin_weights = torch.sin(rot.unsqueeze(2)).expand_as(xy)
    weighted = sin_weights * xy
    return weighted.sum(dim=1)
| |
|
| |
|
def cal_gradient_penalty(netD,
                         real_data,
                         fake_data,
                         device,
                         type='mixed',
                         constant=1.0,
                         lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- device the random mixing weights and the all-ones grad_outputs are moved to
        type (str)               -- which samples the penalty is evaluated at [real | fake | mixed].
        constant (float)         -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    For 'mixed', one random weight in [0, 1) is drawn per sample and shared by
    all of that sample's elements. For 'real' / 'fake' the given tensor itself
    is used, so requires_grad is switched on for the caller's tensor in place.

    Returns (gradient penalty loss, per-sample flattened gradients), or
    (0.0, None) when lambda_gp is not positive.
    Raises NotImplementedError for an unknown `type`.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        batch_size = real_data.shape[0]
        elems_per_sample = real_data.nelement() // batch_size
        # One scalar per sample, replicated across that sample's elements.
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, elems_per_sample).contiguous()
        alpha = alpha.view(*real_data.shape).to(device)
        samples = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    scores = netD(samples)
    grad_seed = torch.ones(scores.size()).to(device)
    # create_graph=True so the penalty itself can be back-propagated.
    sample_grads = torch.autograd.grad(outputs=scores,
                                       inputs=samples,
                                       grad_outputs=grad_seed,
                                       create_graph=True,
                                       retain_graph=True)[0]
    gradients = sample_grads.view(real_data.size(0), -1)

    # The tiny epsilon is added before taking the norm, as in the original formula.
    grad_norms = (gradients + 1e-16).norm(2, dim=1)
    gradient_penalty = ((grad_norms - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients
| |
|
| |
|
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer factory

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    For GroupNorm, the number of groups is fixed to 32; the factory takes the channel count.
    For 'none', None is returned instead of a factory.

    Raises NotImplementedError for any other name.
    """
    if norm_type == 'none':
        return None
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d,
                                 affine=True,
                                 track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d,
                                 affine=False,
                                 track_running_stats=False)
    if norm_type == 'group':
        return functools.partial(nn.GroupNorm, 32)
    raise NotImplementedError('normalization layer [%s] is not found' %
                              norm_type)
| |
|
| |
|
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # Keep the leading (batch) dimension; -1 lets view() infer the rest.
        leading = input.size(0)
        return input.view(leading, -1)
| |
|
| |
|
class ConvBlock(nn.Module):
    """Pre-activation residual block built from three chained convolutions.

    The three convolutions produce out_planes/2, out_planes/4 and out_planes/4
    channels; their outputs are concatenated along the channel axis and added
    to the (optionally projected) input.

    Parameters:
        in_planes (int)  -- number of input channels
        out_planes (int) -- number of output channels
        opt              -- options object; reads `opt.conv3x3` (a 4-item sequence:
                            kernel, stride, dilation, padding) and `opt.norm`
                            ('batch' or 'group')

    Raises NotImplementedError when `opt.norm` is neither 'batch' nor 'group'.
    """

    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()

        # Fail fast on an unsupported norm: otherwise bn1..bn4 would simply be
        # missing and the error would surface later as an opaque AttributeError.
        if opt.norm == 'batch':
            make_norm = nn.BatchNorm2d
        elif opt.norm == 'group':
            make_norm = functools.partial(nn.GroupNorm, 32)
        else:
            raise NotImplementedError(
                'normalization layer [%s] is not found' % opt.norm)

        k, s, d, p = opt.conv3x3
        half_planes = int(out_planes / 2)
        quarter_planes = int(out_planes / 4)

        self.conv1 = conv3x3(in_planes, half_planes, k, s, d, p)
        self.conv2 = conv3x3(half_planes, quarter_planes, k, s, d, p)
        self.conv3 = conv3x3(quarter_planes, quarter_planes, k, s, d, p)

        self.bn1 = make_norm(in_planes)
        self.bn2 = make_norm(half_planes)
        self.bn3 = make_norm(quarter_planes)
        # bn4 is always created (and also placed inside `downsample` below) so
        # that the registered sub-module names stay exactly as before.
        self.bn4 = make_norm(in_planes)

        if in_planes != out_planes:
            # 1x1 projection so the skip path matches the output channel count.
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=1,
                          bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Run the block on `x` and return the residual sum."""
        residual = x

        # Each stage is norm -> in-place ReLU -> conv.
        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        # Concatenate the three stage outputs along the channel dimension.
        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out3 += residual

        return out3
| |
|
| |
|
class Vgg19(torch.nn.Module):
    """Feature extractor that splits torchvision's pretrained VGG-19 into five stages.

    Parameters:
        requires_grad (bool) -- if False (default), every parameter is frozen

    The `features` layers are partitioned by index into slice1..slice5 covering
    [0, 2), [2, 7), [7, 12), [12, 21) and [21, 30). Each layer keeps its original
    index as its name inside the slice.
    NOTE(review): judging by the variable names, each slice presumably ends on a
    ReLU (relu1_1 ... relu5_1) -- confirm against the torchvision layer layout.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index)
        stage_layout = (
            ('slice1', 0, 2),
            ('slice2', 2, 7),
            ('slice3', 7, 12),
            ('slice4', 12, 21),
            ('slice5', 21, 30),
        )
        for attr_name, first, stop in stage_layout:
            stage = torch.nn.Sequential()
            for layer_idx in range(first, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, attr_name, stage)

        if requires_grad:
            return
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage outputs, each stage fed by the previous one."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4,
                  self.slice5)
        activations = []
        hidden = X
        for stage in stages:
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
| |
|
| |
|
class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between the Vgg19 stage outputs of two inputs.

    The five stage outputs are weighted 1/32, 1/16, 1/8, 1/4 and 1, in order.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the loss between `x` and `y`; no gradient flows through `y`'s features."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)

        loss = 0
        for idx, feat_x in enumerate(feats_x):
            # The target features are detached so they act as constants.
            target = feats_y[idx].detach()
            loss += self.weights[idx] * self.criterion(feat_x, target)
        return loss
| |
|