import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision import models
|
|
|
|
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
def init_weights(net, init_type='normal'):
    print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
|
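# Illustrative usage of the initializers above (a sketch, not called anywhere
# in this module); the layer sizes are arbitrary placeholders:
#
#     net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
#     init_weights(net, init_type='kaiming')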
|
|
class FeatureExtraction(nn.Module):
    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(FeatureExtraction, self).__init__()
        # downsampling encoder: each step is a stride-2 conv -> ReLU -> norm,
        # doubling the channel count up to a maximum of 512
        downconv = nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1)
        model = [downconv, nn.ReLU(True), norm_layer(ngf)]
        for i in range(n_layers):
            in_ngf = 2 ** i * ngf if 2 ** i * ngf < 512 else 512
            out_ngf = 2 ** (i + 1) * ngf if 2 ** i * ngf < 512 else 512
            downconv = nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1)
            model += [downconv, nn.ReLU(True)]
            model += [norm_layer(out_ngf)]
        # two extra stride-1 convs at the final resolution
        model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]
        model += [norm_layer(512)]
        model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]

        self.model = nn.Sequential(*model)
        init_weights(self.model, init_type='normal')

    def forward(self, x):
        return self.model(x)
|
|
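# With the defaults (ngf=64, n_layers=3), FeatureExtraction downsamples by a
# factor of 2**4 = 16 and outputs 512 channels, e.g. a (B, C, 256, 192) input
# yields a (B, 512, 16, 12) feature map (sizes shown here for illustration).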
|
|
class FeatureL2Norm(torch.nn.Module):
    def __init__(self):
        super(FeatureL2Norm, self).__init__()

    def forward(self, feature):
        epsilon = 1e-6
        norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
        return torch.div(feature, norm)
|
|
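# FeatureL2Norm divides each spatial position's channel vector by its L2 norm:
# f[b, :, i, j] <- f[b, :, i, j] / sqrt(sum_c f[b, c, i, j]**2 + eps),
# so the correlation below behaves like a cosine similarity between features.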
|
|
class FeatureCorrelation(nn.Module):
    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, feature_A, feature_B):
        b, c, h, w = feature_A.size()
        # reshape both feature maps for batched matrix multiplication
        feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)
        feature_B = feature_B.view(b, c, h * w).transpose(1, 2)
        # (b, h*w, c) x (b, c, h*w) -> (b, h*w, h*w) pairwise dot products
        feature_mul = torch.bmm(feature_B, feature_A)
        correlation_tensor = feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)
        return correlation_tensor
|
|
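# The correlation tensor has shape (b, h*w, h, w): channel k at position (i, j)
# is the dot product between feature_B at (i, j) and the k-th location of
# feature_A. For the 16 x 12 maps used by GMM below, that gives 192 channels.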
|
|
class FeatureRegression(nn.Module):
    def __init__(self, input_nc=512, output_dim=6, use_cuda=True):
        super(FeatureRegression, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.linear = nn.Linear(64 * 4 * 3, output_dim)
        self.tanh = nn.Tanh()
        if use_cuda:
            self.conv.cuda()
            self.linear.cuda()
            self.tanh.cuda()

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        x = self.tanh(x)
        return x
|
|
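# Note: the flattening size 64 * 4 * 3 assumes a 16 x 12 input map (the two
# stride-2 convolutions reduce it to 4 x 3), which matches the GMM configuration
# below for 256 x 192 inputs; other resolutions would need a different Linear size.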
|
|
class AffineGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, out_ch=3):
        super(AffineGridGen, self).__init__()
        self.out_h = out_h
        self.out_w = out_w
        self.out_ch = out_ch

    def forward(self, theta):
        theta = theta.contiguous()
        batch_size = theta.size()[0]
        out_size = torch.Size((batch_size, self.out_ch, self.out_h, self.out_w))
        return F.affine_grid(theta, out_size)
|
|
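# AffineGridGen expects theta of shape (B, 2, 3) (one 2-D affine matrix per
# sample) and returns a (B, out_h, out_w, 2) sampling grid for F.grid_sample.
# On PyTorch >= 1.3 the default align_corners=False behaviour of F.affine_grid
# applies, since no value is passed explicitly here.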
|
|
class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
        super(TpsGridGen, self).__init__()
        self.out_h, self.out_w = out_h, out_w
        self.reg_factor = reg_factor
        self.use_cuda = use_cuda

        # output sampling grid, with coordinates normalized to [-1, 1]
        self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
        self.grid_X, self.grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        # grid_X, grid_Y: size [1, H, W, 1]
        self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
        self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
        if use_cuda:
            self.grid_X = self.grid_X.cuda()
            self.grid_Y = self.grid_Y.cuda()

        # regular grid_size x grid_size grid of TPS control points P
        if use_regular_grid:
            axis_coords = np.linspace(-1, 1, grid_size)
            self.N = grid_size * grid_size
            P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
            P_X = np.reshape(P_X, (-1, 1))  # size (N, 1)
            P_Y = np.reshape(P_Y, (-1, 1))  # size (N, 1)
            P_X = torch.FloatTensor(P_X)
            P_Y = torch.FloatTensor(P_Y)
            self.P_X_base = P_X.clone()
            self.P_Y_base = P_Y.clone()
            self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
            self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
            self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
            if use_cuda:
                self.P_X = self.P_X.cuda()
                self.P_Y = self.P_Y.cuda()
                self.P_X_base = self.P_X_base.cuda()
                self.P_Y_base = self.P_Y_base.cuda()

    def forward(self, theta):
        warped_grid = self.apply_transformation(theta, torch.cat((self.grid_X, self.grid_Y), 3))
        return warped_grid

    def compute_L_inverse(self, X, Y):
        N = X.size()[0]  # number of control points
        # kernel matrix K with U(r) = r^2 * log(r^2) between control points
        Xmat = X.expand(N, N)
        Ymat = Y.expand(N, N)
        P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
        P_dist_squared[P_dist_squared == 0] = 1  # avoid log(0); U(0) = 0 in the limit
        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
        # assemble L = [[K, P], [P^T, 0]] and invert it
        O = torch.FloatTensor(N, 1).fill_(1)
        Z = torch.FloatTensor(3, 3).fill_(0)
        P = torch.cat((O, X, Y), 1)
        L = torch.cat((torch.cat((K, P), 1), torch.cat((P.transpose(0, 1), Z), 1)), 0)
        Li = torch.inverse(L)
        if self.use_cuda:
            Li = Li.cuda()
        return Li

    def apply_transformation(self, theta, points):
        # theta: (B, 2N) or (B, 2N, 1, 1) control-point offsets
        # (x offsets for the N control points first, then y offsets);
        # points: (1 or B, H, W, 2) grid of (x, y) coordinates to be warped
        if theta.dim() == 2:
            theta = theta.unsqueeze(2).unsqueeze(3)

        batch_size = theta.size()[0]
        # absolute control-point targets Q = base grid + predicted offsets
        Q_X = theta[:, :self.N, :, :].squeeze(3)
        Q_Y = theta[:, self.N:, :, :].squeeze(3)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

        # spatial dimensions of the points to transform
        points_b = points.size()[0]
        points_h = points.size()[1]
        points_w = points.size()[2]

        # control points broadcast over the spatial dimensions of `points`
        P_X = self.P_X.expand((1, points_h, points_w, 1, self.N))
        P_Y = self.P_Y.expand((1, points_h, points_w, 1, self.N))

        # TPS coefficients: W for the non-linear part, A for the affine part
        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_X)
        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_Y)
        # reshape to size [B, H, W, 1, N]
        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)

        A_X = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_X)
        A_Y = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_Y)
        # reshape to size [B, H, W, 1, 3]
        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)

        # squared distances between every point and every control point
        points_X_for_summation = points[:, :, :, 0].unsqueeze(3).unsqueeze(4).expand(
            points[:, :, :, 0].size() + (1, self.N))
        points_Y_for_summation = points[:, :, :, 1].unsqueeze(3).unsqueeze(4).expand(
            points[:, :, :, 1].size() + (1, self.N))

        if points_b == 1:
            delta_X = points_X_for_summation - P_X
            delta_Y = points_Y_for_summation - P_Y
        else:
            # expand the control points over the batch dimension
            delta_X = points_X_for_summation - P_X.expand_as(points_X_for_summation)
            delta_Y = points_Y_for_summation - P_Y.expand_as(points_Y_for_summation)

        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
        dist_squared[dist_squared == 0] = 1  # avoid log(0); U(0) = 0 in the limit
        U = torch.mul(dist_squared, torch.log(dist_squared))

        # expand the grid over the batch dimension if needed
        points_X_batch = points[:, :, :, 0].unsqueeze(3)
        points_Y_batch = points[:, :, :, 1].unsqueeze(3)
        if points_b == 1:
            points_X_batch = points_X_batch.expand((batch_size,) + points_X_batch.size()[1:])
            points_Y_batch = points_Y_batch.expand((batch_size,) + points_Y_batch.size()[1:])

        # warped coordinates: affine part + weighted sum of the TPS kernel values
        points_X_prime = A_X[:, :, :, :, 0] + \
                         torch.mul(A_X[:, :, :, :, 1], points_X_batch) + \
                         torch.mul(A_X[:, :, :, :, 2], points_Y_batch) + \
                         torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4)

        points_Y_prime = A_Y[:, :, :, :, 0] + \
                         torch.mul(A_Y[:, :, :, :, 1], points_X_batch) + \
                         torch.mul(A_Y[:, :, :, :, 2], points_Y_batch) + \
                         torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4)

        return torch.cat((points_X_prime, points_Y_prime), 3)
|
|
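# TpsGridGen maps a (B, 2 * grid_size**2) theta (x offsets of the control points
# followed by y offsets, relative to the regular grid) to a (B, out_h, out_w, 2)
# sampling grid that can be passed to F.grid_sample.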
|
|
|
|
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # construct the U-Net recursively, from the innermost block outward
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        return self.model(input)
|
|
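# UnetGenerator downsamples num_downs times, so the input height and width must
# be divisible by 2 ** num_downs (e.g. num_downs=6 works for 256 x 192 inputs).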
|
|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)
        if norm_layer is not None:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        if outermost:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv]
            up = [uprelu, upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            if norm_layer is None:
                up = [uprelu, upsample, upconv]
            else:
                up = [uprelu, upsample, upconv, upnorm]
            model = down + up
        else:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            if norm_layer is None:
                down = [downrelu, downconv]
                up = [uprelu, upsample, upconv]
            else:
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upsample, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # concatenate the skip connection along the channel dimension
            return torch.cat([x, self.model(x)], 1)
|
|
|
|
class ResidualBlock(nn.Module):
    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU(True)
        if norm_layer is None:
            # conv -> ReLU -> conv, without normalization
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
            )
        else:
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features)
            )

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out
|
|
|
|
class ResUnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(ResUnetGenerator, self).__init__()
        # construct the residual U-Net recursively, from the innermost block outward
        unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                                innermost=True)

        for i in range(num_downs - 5):
            unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                    norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                                norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        output = self.model(input)
        return output
|
|
|
|
class ResUnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(ResUnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
                             stride=2, padding=1, bias=use_bias)
        # two residual blocks after each down-/up-sampling step
        res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
        res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]

        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)
        if norm_layer is not None:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        if outermost:
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            up = [upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            if norm_layer is None:
                up = [upsample, upconv, uprelu] + res_upconv
            else:
                up = [upsample, upconv, upnorm, uprelu] + res_upconv
            model = down + up
        else:
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            if norm_layer is None:
                down = [downconv, downrelu] + res_downconv
                up = [upsample, upconv, uprelu] + res_upconv
            else:
                down = [downconv, downnorm, downrelu] + res_downconv
                up = [upsample, upconv, upnorm, uprelu] + res_upconv

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # concatenate the skip connection along the channel dimension
            return torch.cat([x, self.model(x)], 1)
|
|
|
|
class Vgg19(nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # torchvision >= 0.13 prefers the weights= argument; pretrained=True is
        # kept here for compatibility with older versions
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        # the five slices end at relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
|
|
def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size, b = channels, (c, d) = spatial dims
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    # normalize by the total number of elements
    return G.div(a * b * c * d)
|
|
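# For example, a (1, 64, 32, 32) activation gives features of shape (64, 1024)
# and a 64 x 64 Gram matrix. Batch and channel dimensions are folded together,
# so entries mix samples when the batch size is greater than 1.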
|
|
class StyleLoss(nn.Module):
    def __init__(self):
        super(StyleLoss, self).__init__()

    def forward(self, x, y):
        # MSE between the Gram matrices of the two feature maps, with a fixed scale factor
        Gx = gram_matrix(x)
        Gy = gram_matrix(y)
        return F.mse_loss(Gx, Gy) * 30000000
|
|
class VGGLoss(nn.Module):
    def __init__(self, model=None):
        super(VGGLoss, self).__init__()
        if model is None:
            self.vgg = Vgg19()
        else:
            self.vgg = model

        self.vgg.cuda()
        self.criterion = nn.L1Loss()
        self.style_criterion = StyleLoss()
        self.weights = [1.0, 1.0, 1.0, 1.0, 1.0]
        self.style_weights = [1.0, 1.0, 1.0, 1.0, 1.0]

    def forward(self, x, y, style=False):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        if style:
            # return both the perceptual (L1) loss and the Gram-matrix style loss
            style_loss = 0
            for i in range(len(x_vgg)):
                this_loss = (self.weights[i] *
                             self.criterion(x_vgg[i], y_vgg[i].detach()))
                this_style_loss = (self.style_weights[i] *
                                   self.style_criterion(x_vgg[i], y_vgg[i].detach()))
                loss += this_loss
                style_loss += this_style_loss
            return loss, style_loss

        for i in range(len(x_vgg)):
            this_loss = (self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()))
            loss += this_loss
        return loss
|
|
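# Illustrative use (a sketch; requires a CUDA device because VGGLoss moves its
# VGG-19 backbone to the GPU in __init__). Tensor names are placeholders:
#
#     criterion_vgg = VGGLoss()
#     percept = criterion_vgg(fake_image, real_image)
#     percept, style = criterion_vgg(fake_image, real_image, style=True)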
|
|
class GMM(nn.Module):
    """ Geometric Matching Module
    """

    def __init__(self, opt, input_nc):
        super(GMM, self).__init__()
        self.extractionA = FeatureExtraction(input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
        self.extractionB = FeatureExtraction(3, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
        self.l2norm = FeatureL2Norm()
        self.correlation = FeatureCorrelation()
        self.regression = FeatureRegression(input_nc=192, output_dim=2 * opt.grid_size ** 2, use_cuda=True)
        self.gridGen = TpsGridGen(opt.fine_height, opt.fine_width, use_cuda=True, grid_size=opt.grid_size)

    def forward(self, inputA, inputB):
        featureA = self.extractionA(inputA)
        featureB = self.extractionB(inputB)
        featureA = self.l2norm(featureA)
        featureB = self.l2norm(featureB)
        correlation = self.correlation(featureA, featureB)

        theta = self.regression(correlation)
        grid = self.gridGen(theta)
        return grid, theta
|
|
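# Illustrative use of GMM (a sketch; the option values and channel count are
# assumptions based on the usual CP-VTON-style setup with opt.fine_height=256,
# opt.fine_width=192, opt.grid_size=5, and CUDA available):
#
#     gmm = GMM(opt, input_nc=22).cuda()
#     grid, theta = gmm(agnostic, cloth)
#     warped_cloth = F.grid_sample(cloth, grid, padding_mode='border')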
|
|
def save_checkpoint(model, save_path):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    torch.save(model.state_dict(), save_path)


def load_checkpoint(model, checkpoint_path):
    if not os.path.exists(checkpoint_path):
        print('No checkpoint!')
        return
    model.load_state_dict(torch.load(checkpoint_path))
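

# Optional smoke test (added for illustration; not part of the training or
# inference pipeline). Runs on CPU with small dummy tensors and only executes
# when this file is run directly.
if __name__ == '__main__':
    tps = TpsGridGen(out_h=64, out_w=48, grid_size=3, use_cuda=False)
    theta = torch.zeros(2, 2 * 3 * 3)  # zero offsets -> (near-)identity warp
    grid = tps(theta)
    print('TPS grid:', grid.shape)  # expected: torch.Size([2, 64, 48, 2])

    unet = UnetGenerator(input_nc=4, output_nc=3, num_downs=6)
    out = unet(torch.randn(2, 4, 256, 192))
    print('U-Net output:', out.shape)  # expected: torch.Size([2, 3, 256, 192])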
|
|
|
|