| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from models.networks.architecture import VGG19 |
| from torch.nn import init |
| from torchvision import models |
| import numpy as np |
| from pytorch_msssim import ssim, ms_ssim |
| from torchvision.models import vgg16, VGG16_Weights |
|
|
|
|
class LaplacianLoss(nn.Module):
    """L1 loss between the Laplacians of grayscale versions of two images.

    Penalizes differences in second-order structure (edges/texture), which a
    plain pixel-wise loss tends to blur.
    """

    def __init__(self, device):
        super(LaplacianLoss, self).__init__()
        self.device = device
        # 4-neighbour discrete Laplacian kernel, shape (out=1, in=1, 3, 3).
        kernel = torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]],
                              dtype=torch.float32, device=device)
        # Bug fix: the kernel used to be a plain attribute, so it stayed on the
        # construction device when the module was moved with .to()/.cuda().
        # A non-persistent buffer follows device/dtype moves and stays out of
        # state_dict (backward compatible with existing checkpoints).
        self.register_buffer('kernel', kernel, persistent=False)

    def rgb_to_grayscale(self, x):
        """Convert an (N, 3, H, W) RGB batch to (N, 1, H, W) luma
        using ITU-R BT.601 weights."""
        return (0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]).unsqueeze(1)

    def forward(self, pred, target):
        """Return mean |Laplacian(pred) - Laplacian(target)| over the
        grayscale images (zero padding keeps the spatial size)."""
        pred_gray = self.rgb_to_grayscale(pred)
        target_gray = self.rgb_to_grayscale(target)
        pred_lap = F.conv2d(pred_gray, self.kernel, padding=1)
        target_lap = F.conv2d(target_gray, self.kernel, padding=1)
        return F.l1_loss(pred_lap, target_lap)
|
|
class GANLoss(nn.Module):
    """Adversarial objective that abstracts over several GAN loss flavors.

    Supported ``gan_mode`` values:
      * 'ls'             -- least-squares GAN (MSE against label tensors)
      * 'original'       -- standard GAN (BCE with logits)
      * 'w'              -- Wasserstein critic objective (signed mean)
      * 'hinge'          -- hinge loss for D, -mean(input) for G
      * 'non_saturating' -- softplus-based non-saturating loss
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Label / zero tensors are created lazily on first use and cached.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        self.opt = opt
        if gan_mode not in ['ls', 'original', 'w', 'hinge', 'non_saturating']:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        """Return a constant real/fake label tensor broadcast to
        ``input``'s shape (cached, no gradient)."""
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
                self.fake_label_tensor.requires_grad_(False)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        """Return a zero tensor broadcast to ``input``'s shape, on
        ``input``'s device.

        Bug fix: this tensor used to be created on a hard-coded 'cuda'
        device, which crashed on CPU-only runs and mismatched inputs living
        on non-default GPUs. It is now created from ``input`` and re-created
        whenever the input device changes.
        """
        if self.zero_tensor is None or self.zero_tensor.device != input.device:
            self.zero_tensor = torch.zeros(1, device=input.device)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Compute the loss for one discriminator output tensor (logits)."""
        if self.gan_mode == 'non_saturating':
            if for_discriminator:
                if target_is_real:
                    return F.softplus(-input).mean()
                else:
                    return F.softplus(input).mean()
            else:
                # Generator always pushes its samples towards "real".
                return F.softplus(-input).mean()
        elif self.gan_mode == 'original':
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    # max(0, 1 - D(x)) averaged, written via min/negate.
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    # max(0, 1 + D(G(z))) averaged.
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # Wasserstein ('w'): critic maximizes D(real) - D(fake).
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Apply the loss to a single prediction tensor or to the list of
        (possibly multi-scale) predictions a multi-discriminator returns;
        list results are averaged over discriminators."""
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # Intermediate-feature lists: only the final logits count.
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        else:
            return self.loss(input, target_is_real, for_discriminator)
|
|
|
|
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 feature maps of
    two images, with deeper layers weighted more heavily."""

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        # Per-layer weights, shallow -> deep.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted sum of per-layer L1 feature distances.
        Gradients flow only through x; y's features are detached."""
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        total = 0
        for idx, feat_x in enumerate(feats_x):
            total = total + self.weights[idx] * self.criterion(feat_x, feats_y[idx].detach())
        return total
|
|
|
|
class KLDLoss(nn.Module):
    """Analytic KL divergence KL(N(mu, sigma^2) || N(0, 1)) for a diagonal
    Gaussian parameterized by ``mu`` and ``logvar``, summed over all
    elements (not averaged over the batch)."""

    def forward(self, mu, logvar):
        kld_terms = 1 + logvar - mu * mu - torch.exp(logvar)
        return -0.5 * kld_terms.sum()
|
|
|
|
class PatchSim(nn.Module):
    """Compute patch-wise self-similarity within a single feature map:
    similarity of sampled query positions against their local key windows
    (or the whole map). Used by SpatialCorrelativeLoss below."""

    def __init__(self, patch_nums=256, patch_size=None, norm=True):
        # patch_nums: number of spatial positions sampled as queries
        #   (<= 0 means use every position as a query).
        # patch_size: side length of the local key window around each query;
        #   only used when strictly smaller than both spatial dims.
        # norm: if True, L2-normalize features over channels (cosine
        #   similarity); otherwise scale by sqrt(C) and squash with tanh.
        super(PatchSim, self).__init__()
        self.patch_nums = patch_nums
        self.patch_size = patch_size
        self.use_norm = norm


    def forward(self, feat, patch_ids=None):
        """Return (patch_sim, patch_ids): one similarity row per query and
        the sampled flat position indices, so a second feature map can be
        evaluated at the same positions by passing patch_ids back in."""
        # NOTE(review): dims are unpacked as (B, C, W, H); PyTorch layout is
        # conventionally (B, C, H, W), so the W/H names here may be swapped —
        # verify before relying on the axis names.
        B, C, W, H = feat.size()
        # Zero-center each channel over the spatial dims before measuring
        # similarity, then normalize/scale per `use_norm`.
        feat = feat - feat.mean(dim=[-2, -1], keepdim=True)
        feat = F.normalize(feat, dim=1) if self.use_norm else feat / np.sqrt(C)
        query, key, patch_ids = self.select_patch(feat, patch_ids=patch_ids)
        patch_sim = query.bmm(key) if self.use_norm else torch.tanh(query.bmm(key) / 10)
        if patch_ids is not None:
            # One row of key-similarities per sampled query position.
            patch_sim = patch_sim.view(B, len(patch_ids), -1)


        return patch_sim, patch_ids


    def select_patch(self, feat, patch_ids=None):
        """Sample query vectors and their matching keys from ``feat``.

        Returns (feat_query, feat_key, patch_ids). When patch_nums > 0 and
        the patch window fits inside the map, each query is a single
        position's feature vector and its key is the flattened patch_size x
        patch_size window around it (clamped to the map borders); otherwise
        keys span the entire map.
        """
        B, C, W, H = feat.size()
        pw, ph = self.patch_size, self.patch_size
        # (B, C, W, H) -> (B, W*H, C): one feature vector per spatial position.
        feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
        if self.patch_nums > 0:
            if patch_ids is None:
                # Random sample of flat spatial indices (same for whole batch).
                patch_ids = torch.randperm(feat_reshape.size(1), device=feat.device)
                patch_ids = patch_ids[:int(min(self.patch_nums, patch_ids.size(0)))]
            feat_query = feat_reshape[:, patch_ids, :]
            feat_key = []
            Num = feat_query.size(1)
            if pw < W and ph < H:
                # Recover 2-D coordinates of each sampled index.
                # NOTE(review): both // and % use W — consistent only if the
                # flattened stride matches; confirm against the layout above.
                pos_x, pos_y = patch_ids // W, patch_ids % W
                # Top-left corner of the window centered on the query,
                # clamped so the pw x ph window stays inside the map.
                left, top = pos_x - int(pw / 2), pos_y - int(ph / 2)
                left, top = torch.where(left > 0, left, torch.zeros_like(left)), torch.where(top > 0, top,
                                                                                            torch.zeros_like(top))
                start_x = torch.where(left > (W - pw), (W - pw) * torch.ones_like(left), left)
                start_y = torch.where(top > (H - ph), (H - ph) * torch.ones_like(top), top)
                for i in range(Num):
                    feat_key.append(feat[:, :, start_x[i]:start_x[i] + pw, start_y[i]:start_y[i] + ph])
                # (Num, B, C, pw, ph) -> (B, Num, C, pw, ph) -> (B*Num, C, pw*ph)
                feat_key = torch.stack(feat_key, dim=0).permute(1, 0, 2, 3, 4)
                feat_key = feat_key.reshape(B * Num, C, pw * ph)
                feat_query = feat_query.reshape(B * Num, 1, C)
            else:
                # Window does not fit: keys are the whole feature map.
                feat_key = feat.reshape(B, C, W * H)
        else:
            # No sampling: every position is a query, keys are the whole map.
            feat_query = feat.reshape(B, C, H * W).permute(0, 2, 1)
            feat_key = feat.reshape(B, C, H * W)


        return feat_query, feat_key, patch_ids
|
|
|
|
class SSIMLoss(nn.Module):
    """Structural-similarity loss: ``1 - SSIM(img1, img2)``.

    ``data_range`` defaults to 2.0, i.e. inputs spanning [-1, 1].
    """

    def __init__(self, data_range=2.0, size_average=True, channel=3, nonnegative_ssim=False):
        super(SSIMLoss, self).__init__()
        self.data_range = data_range
        self.size_average = size_average
        # Kept for interface compatibility; ssim() infers channels itself.
        self.channel = channel
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, img1, img2):
        """Return 1 minus the SSIM similarity of the two image batches."""
        similarity = ssim(
            img1,
            img2,
            data_range=self.data_range,
            size_average=self.size_average,
            nonnegative_ssim=self.nonnegative_ssim,
        )
        return 1.0 - similarity
|
|
|
|
class SpatialCorrelativeLoss(nn.Module):
    """Spatial correlative loss: compares the patch self-similarity maps of a
    source and target feature map (optionally contrasted against a third
    'other' map as negatives), so spatial structure rather than raw
    appearance is matched."""

    def __init__(self, loss_mode='cos', patch_nums=256, patch_size=32, norm=True, use_conv=True,
                 init_type='normal', init_gain=0.02, gpu_ids=[], T=0.1):
        # loss_mode: 'cos' | 'l1' | 'info' (InfoNCE-style contrastive).
        # patch_nums / patch_size / norm: forwarded to PatchSim.
        # use_conv: if True, features pass through a small lazily-created
        #   1x1-conv projection head before similarity is computed.
        # init_type / init_gain / gpu_ids: forwarded to init_net for the head.
        # T: temperature for the contrastive ('info') branch.
        super(SpatialCorrelativeLoss, self).__init__()
        self.patch_sim = PatchSim(patch_nums=patch_nums, patch_size=patch_size, norm=norm)
        self.patch_size = patch_size
        self.patch_nums = patch_nums
        self.norm = norm
        self.use_conv = use_conv
        # Flag meaning "projection heads have been created"; see update_init_.
        self.conv_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids
        self.loss_mode = loss_mode
        self.T = T
        self.criterion = nn.L1Loss() if norm else nn.SmoothL1Loss()
        self.cross_entropy_loss = nn.CrossEntropyLoss()


    def update_init_(self):
        # NOTE(review): create_conv does not set this flag itself — the caller
        # must invoke update_init_() after the first forward pass, otherwise
        # a fresh (re-initialized) head is built on every cal_sim call.
        self.conv_init = True


    def create_conv(self, feat, layer):
        """Build and initialize the 1x1-conv projection head for ``layer``,
        sized from ``feat``'s channel count, and attach it as conv_<layer>."""
        input_nc = feat.size(1)
        output_nc = max(32, input_nc // 4)
        conv = nn.Sequential(*[nn.Conv2d(input_nc, output_nc, kernel_size=1),
                               nn.ReLU(),
                               nn.Conv2d(output_nc, output_nc, kernel_size=1)])
        conv.to(feat.device)
        setattr(self, 'conv_%d' % layer, conv)
        init_net(conv, self.init_type, self.init_gain, self.gpu_ids)


    def cal_sim(self, f_src, f_tgt, f_other=None, layer=0, patch_ids=None):
        """Optionally project the features, then compute similarity maps for
        source, target, and (if given) the 'other' features, all sampled at
        the same patch positions so the maps are directly comparable."""
        if self.use_conv:
            if not self.conv_init:
                self.create_conv(f_src, layer)
            conv = getattr(self, 'conv_%d' % layer)
            f_src, f_tgt = conv(f_src), conv(f_tgt)
            f_other = conv(f_other) if f_other is not None else None
        sim_src, patch_ids = self.patch_sim(f_src, patch_ids)
        # Reuse the source's sampled positions for target and 'other'.
        sim_tgt, patch_ids = self.patch_sim(f_tgt, patch_ids)
        if f_other is not None:
            sim_other, _ = self.patch_sim(f_other, patch_ids)
        else:
            sim_other = None


        return sim_src, sim_tgt, sim_other


    def compare_sim(self, sim_src, sim_tgt, sim_other):
        """Reduce the similarity maps to a scalar loss per ``loss_mode``."""
        B, Num, N = sim_src.size()
        if self.loss_mode == 'info' or sim_other is not None:
            # Contrastive branch: matching (src, tgt) similarity rows are the
            # positives; rows paired with 'other' serve as negatives.
            # NOTE(review): if loss_mode == 'info' but sim_other is None this
            # raises inside F.normalize — callers must supply f_other here.
            sim_src = F.normalize(sim_src, dim=-1)
            sim_tgt = F.normalize(sim_tgt, dim=-1)
            sim_other = F.normalize(sim_other, dim=-1)
            sam_neg1 = (sim_src.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_neg2 = (sim_tgt.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_self = (sim_src.bmm(sim_tgt.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_self = torch.cat([sam_self, sam_neg1, sam_neg2], dim=-1)
            # Target class = row index mod Num: the positive logit sits on the
            # (per-batch) diagonal of the first Num columns.
            loss = self.cross_entropy_loss(sam_self, torch.arange(0, sam_self.size(0), dtype=torch.long,
                                                                  device=sim_src.device) % (Num))
        else:
            # Keep only the similarities at the target's top quarter of most
            # similar positions; everything else is zeroed in both maps.
            tgt_sorted, _ = sim_tgt.sort(dim=-1, descending=True)
            num = int(N / 4)
            src = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_src, sim_src)
            tgt = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_tgt, sim_tgt)
            if self.loss_mode == 'l1':
                # Rescale by N/num so the sparsified maps keep their magnitude.
                loss = self.criterion((N / num) * src, (N / num) * tgt)
            elif self.loss_mode == 'cos':
                # Cosine similarity of the masked rows, pushed towards 1.
                sim_pos = F.cosine_similarity(src, tgt, dim=-1)
                loss = self.criterion(torch.ones_like(sim_pos), sim_pos)
            else:
                # NOTE(review): message says 'padding' but refers to loss_mode.
                raise NotImplementedError('padding [%s] is not implemented' % self.loss_mode)


        return loss


    def loss(self, f_src, f_tgt, f_other=None, layer=0):
        """Convenience wrapper: compute similarity maps for the three feature
        maps at ``layer`` and reduce them to the scalar loss."""
        sim_src, sim_tgt, sim_other = self.cal_sim(f_src, f_tgt, f_other, layer)
        loss = self.compare_sim(sim_src, sim_tgt, sim_other)
        return loss
|
|
|
|
class Normalization(nn.Module):
    """Normalize an image batch with the ImageNet channel mean/std
    (the statistics torchvision's pretrained VGG models expect)."""

    def __init__(self, device):
        super(Normalization, self).__init__()
        # Reshape to (C, 1, 1) so they broadcast over (N, C, H, W) inputs.
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1)
        # Bug fix: these were plain tensor attributes, so they stayed on the
        # construction device when the module was moved with .to()/.cuda().
        # Non-persistent buffers follow device moves and stay out of
        # state_dict (backward compatible with existing checkpoints).
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std', std, persistent=False)

    def forward(self, img):
        """Return the channel-wise standardized image batch."""
        return (img - self.mean) / self.std
|
|
|
|
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Move ``net`` to the first GPU in ``gpu_ids`` (if any) and optionally
    initialize its weights with ``init_type``/``init_gain``.

    Returns the same module for chaining.
    """
    if gpu_ids:
        # GPU placement was requested; fail loudly if CUDA is absent.
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
|
|
|
|
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize every Conv/Linear weight in ``net`` with the scheme named by
    ``init_type`` ('normal' | 'xavier' | 'kaiming' | 'orthogonal'), zero all
    biases, and give BatchNorm2d layers N(1, init_gain) weights.

    Raises NotImplementedError for an unknown ``init_type``.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm weights start near 1 so normalization begins neutral.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|
|
|
|
class VGG16(nn.Module):
    """Pretrained VGG16 split into named ReLU stages so intermediate
    activations can be extracted for perceptual/contrastive objectives."""

    # (stage name, start index, stop index) into torchvision's
    # vgg16().features sequence; each stage ends just after its ReLU.
    _SLICES = [
        ('relu1_1', 0, 2), ('relu1_2', 2, 4),
        ('relu2_1', 4, 7), ('relu2_2', 7, 9),
        ('relu3_1', 9, 12), ('relu3_2', 12, 14), ('relu3_3', 14, 16),
        ('relu4_1', 16, 18), ('relu4_2', 18, 21), ('relu4_3', 21, 23),
        ('relu5_1', 23, 26), ('relu5_2', 26, 28), ('relu5_3', 28, 30),
    ]

    def __init__(self):
        super(VGG16, self).__init__()
        features = vgg16(weights=VGG16_Weights.DEFAULT).features
        # Build one Sequential per stage, keeping the original layer indices
        # as submodule names (so state_dict keys match the hand-rolled form).
        for name, start, stop in self._SLICES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)

    def forward(self, x, layers=None, encode_only=False, resize=False):
        """Run the stages in order and collect each activation.

        Returns the full {stage name: activation} dict, or — when
        ``encode_only`` is True — the activations whose positional index is
        in ``layers`` (or relu3_1 alone if ``layers`` is empty).
        """
        out = {}
        feat = x
        for name, _, _ in self._SLICES:
            feat = getattr(self, name)(feat)
            out[name] = feat
        if encode_only:
            if len(layers) > 0:
                feats = []
                # Positional index over insertion order matches the original
                # hand-written dict ordering.
                for layer, key in enumerate(out):
                    if layer in layers:
                        feats.append(out[key])
                return feats
            else:
                return out['relu3_1']
        return out