# CR-Net: models/networks/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.architecture import VGG19
from torch.nn import init
import numpy as np
from pytorch_msssim import ssim
from torchvision.models import vgg16, VGG16_Weights
class LaplacianLoss(nn.Module):
def __init__(self, device):
super(LaplacianLoss, self).__init__()
self.device = device
self.kernel = torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=torch.float32, device=self.device)
def rgb_to_grayscale(self, x):
return (0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]).unsqueeze(1)
def forward(self, pred, target):
pred_gray = self.rgb_to_grayscale(pred)
target_gray = self.rgb_to_grayscale(target)
pred_lap = F.conv2d(pred_gray, self.kernel, padding=1)
target_lap = F.conv2d(target_gray, self.kernel, padding=1)
return F.l1_loss(pred_lap, target_lap)
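
# Usage sketch (hypothetical tensor names, not part of the original module):
# LaplacianLoss compares second-order edge responses of the grayscale images,
# so both inputs are expected to be (B, 3, H, W) float tensors.
#
#   lap_loss = LaplacianLoss(device=torch.device('cuda'))
#   loss = lap_loss(pred, target)
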
class GANLoss(nn.Module):
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor, opt=None):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_tensor = None
self.fake_label_tensor = None
self.zero_tensor = None
self.Tensor = tensor
self.gan_mode = gan_mode
self.opt = opt
if gan_mode not in ['ls', 'original', 'w', 'hinge', 'non_saturating']:
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            if self.real_label_tensor is None:
                # Cache the target label on the prediction's device/dtype so the
                # loss also works on CPU inputs.
                self.real_label_tensor = torch.full((1,), self.real_label,
                                                    dtype=input.dtype, device=input.device)
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = torch.full((1,), self.fake_label,
                                                    dtype=input.dtype, device=input.device)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        if self.zero_tensor is None:
            # Follow the input's device rather than assuming CUDA is available.
            self.zero_tensor = torch.zeros(1, dtype=input.dtype, device=input.device)
        return self.zero_tensor.expand_as(input)
def loss(self, input, target_is_real, for_discriminator=True):
if self.gan_mode == 'non_saturating':
if for_discriminator:
if target_is_real:
return F.softplus(-input).mean()
else:
return F.softplus(input).mean()
else:
return F.softplus(-input).mean()
elif self.gan_mode == 'original':
target_tensor = self.get_target_tensor(input, target_is_real)
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
return loss
elif self.gan_mode == 'ls':
target_tensor = self.get_target_tensor(input, target_is_real)
return F.mse_loss(input, target_tensor)
elif self.gan_mode == 'hinge':
if for_discriminator:
if target_is_real:
minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
assert target_is_real, "The generator's hinge loss must be aiming for real"
loss = -torch.mean(input)
return loss
else:
if target_is_real:
return -input.mean()
else:
return input.mean()
def __call__(self, input, target_is_real, for_discriminator=True):
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
pred_i = pred_i[-1]
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
loss += new_loss
return loss / len(input)
else:
return self.loss(input, target_is_real, for_discriminator)
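
# Usage sketch (hypothetical variable names): GANLoss accepts either a raw
# discriminator output or the nested list structure produced by multiscale
# discriminators, averaging the per-scale losses in the latter case.
#
#   criterion_gan = GANLoss('hinge')
#   d_loss = criterion_gan(d_real, True) + criterion_gan(d_fake, False)
#   g_loss = criterion_gan(d_fake, True, for_discriminator=False)
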
class VGGLoss(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        # Place the VGG19 feature extractor on the first listed GPU, or on CPU
        # if no GPU ids are given.
        device = torch.device('cuda:%d' % gpu_ids[0]) if gpu_ids else torch.device('cpu')
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
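
# Usage sketch (assumes inputs are already preprocessed the way VGG19 expects):
#
#   vgg_loss = VGGLoss(gpu_ids=[0])
#   loss = vgg_loss(fake, real)   # weighted L1 over five feature levels
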
class KLDLoss(nn.Module):
def forward(self, mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
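
# KLDLoss is the closed-form KL divergence between N(mu, sigma^2) and N(0, I)
# used in VAEs: KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), where the
# `logvar` argument is log(sigma^2).
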
class PatchSim(nn.Module):
def __init__(self, patch_nums=256, patch_size=None, norm=True):
super(PatchSim, self).__init__()
self.patch_nums = patch_nums
self.patch_size = patch_size
self.use_norm = norm
    def forward(self, feat, patch_ids=None):
        B, C, H, W = feat.size()
        feat = feat - feat.mean(dim=[-2, -1], keepdim=True)
        feat = F.normalize(feat, dim=1) if self.use_norm else feat / np.sqrt(C)
        query, key, patch_ids = self.select_patch(feat, patch_ids=patch_ids)
        patch_sim = query.bmm(key) if self.use_norm else torch.tanh(query.bmm(key) / 10)
        if patch_ids is not None:
            patch_sim = patch_sim.view(B, len(patch_ids), -1)
        return patch_sim, patch_ids

    def select_patch(self, feat, patch_ids=None):
        B, C, H, W = feat.size()
        pw, ph = self.patch_size, self.patch_size
        feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)  # (B, H*W, C)
        if self.patch_nums > 0:
            if patch_ids is None:
                patch_ids = torch.randperm(feat_reshape.size(1), device=feat.device)
                patch_ids = patch_ids[:int(min(self.patch_nums, patch_ids.size(0)))]
            feat_query = feat_reshape[:, patch_ids, :]  # (B, Num, C)
            feat_key = []
            Num = feat_query.size(1)
            if pw < W and ph < H:
                # Recover the (row, col) of each sampled index in the H x W grid;
                # dividing by W (rather than H) also handles non-square maps.
                pos_h, pos_w = patch_ids // W, patch_ids % W
                # Clamp each patch window so it stays inside the feature map.
                top = (pos_h - ph // 2).clamp(min=0, max=H - ph)
                left = (pos_w - pw // 2).clamp(min=0, max=W - pw)
                for i in range(Num):
                    feat_key.append(feat[:, :, top[i]:top[i] + ph, left[i]:left[i] + pw])
                feat_key = torch.stack(feat_key, dim=0).permute(1, 0, 2, 3, 4)  # (B, Num, C, ph, pw)
                feat_key = feat_key.reshape(B * Num, C, ph * pw)
                feat_query = feat_query.reshape(B * Num, 1, C)
            else:
                feat_key = feat.reshape(B, C, H * W)
        else:
            feat_query = feat.reshape(B, C, H * W).permute(0, 2, 1)
            feat_key = feat.reshape(B, C, H * W)
        return feat_query, feat_key, patch_ids
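
# Usage sketch (hypothetical feature maps): PatchSim returns, for each of the
# patch_nums sampled query positions, its similarity to the surrounding
# patch_size x patch_size window; the returned patch_ids can be reused so two
# feature maps are compared at the same spatial locations.
#
#   patch_sim = PatchSim(patch_nums=256, patch_size=32)
#   sim_a, ids = patch_sim(feat_a)       # sample random positions
#   sim_b, _ = patch_sim(feat_b, ids)    # reuse them on another feature map
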
class SSIMLoss(nn.Module):
def __init__(self, data_range=2.0, size_average=True, channel=3, nonnegative_ssim=False):
super(SSIMLoss, self).__init__()
self.data_range = data_range
self.size_average = size_average
self.channel = channel
self.nonnegative_ssim = nonnegative_ssim
def forward(self, img1, img2):
ssim_value = ssim(img1, img2, data_range=self.data_range, size_average=self.size_average,
nonnegative_ssim=self.nonnegative_ssim)
loss = 1.0 - ssim_value
return loss
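
# Note: data_range is the value range of the inputs, so the default of 2.0
# matches images normalized to [-1, 1]; pass data_range=1.0 for [0, 1] inputs.
# A minimal sketch:
#
#   ssim_loss = SSIMLoss(data_range=2.0)
#   loss = ssim_loss(pred, target)   # 1 - SSIM, so lower is better
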
class SpatialCorrelativeLoss(nn.Module):
def __init__(self, loss_mode='cos', patch_nums=256, patch_size=32, norm=True, use_conv=True,
init_type='normal', init_gain=0.02, gpu_ids=[], T=0.1):
super(SpatialCorrelativeLoss, self).__init__()
self.patch_sim = PatchSim(patch_nums=patch_nums, patch_size=patch_size, norm=norm)
self.patch_size = patch_size
self.patch_nums = patch_nums
self.norm = norm
self.use_conv = use_conv
self.conv_init = False
self.init_type = init_type
self.init_gain = init_gain
self.gpu_ids = gpu_ids
self.loss_mode = loss_mode
self.T = T
self.criterion = nn.L1Loss() if norm else nn.SmoothL1Loss()
self.cross_entropy_loss = nn.CrossEntropyLoss()
def update_init_(self):
self.conv_init = True
def create_conv(self, feat, layer):
input_nc = feat.size(1)
output_nc = max(32, input_nc // 4)
conv = nn.Sequential(*[nn.Conv2d(input_nc, output_nc, kernel_size=1),
nn.ReLU(),
nn.Conv2d(output_nc, output_nc, kernel_size=1)])
conv.to(feat.device)
setattr(self, 'conv_%d' % layer, conv)
init_net(conv, self.init_type, self.init_gain, self.gpu_ids)
def cal_sim(self, f_src, f_tgt, f_other=None, layer=0, patch_ids=None):
if self.use_conv:
if not self.conv_init:
self.create_conv(f_src, layer)
conv = getattr(self, 'conv_%d' % layer)
f_src, f_tgt = conv(f_src), conv(f_tgt)
f_other = conv(f_other) if f_other is not None else None
sim_src, patch_ids = self.patch_sim(f_src, patch_ids)
sim_tgt, patch_ids = self.patch_sim(f_tgt, patch_ids)
if f_other is not None:
sim_other, _ = self.patch_sim(f_other, patch_ids)
else:
sim_other = None
return sim_src, sim_tgt, sim_other
    def compare_sim(self, sim_src, sim_tgt, sim_other):
        B, Num, N = sim_src.size()
        if self.loss_mode == 'info' or sim_other is not None:
            # InfoNCE-style contrastive loss; the 'info' mode requires sim_other.
            sim_src = F.normalize(sim_src, dim=-1)
            sim_tgt = F.normalize(sim_tgt, dim=-1)
            sim_other = F.normalize(sim_other, dim=-1)
            sam_neg1 = (sim_src.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_neg2 = (sim_tgt.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_self = (sim_src.bmm(sim_tgt.permute(0, 2, 1))).view(-1, Num) / self.T
            sam_self = torch.cat([sam_self, sam_neg1, sam_neg2], dim=-1)
            loss = self.cross_entropy_loss(sam_self, torch.arange(0, sam_self.size(0), dtype=torch.long,
                                                                  device=sim_src.device) % Num)
        else:
            # Keep only the top quarter of target similarities and compare src
            # and tgt on that support.
            tgt_sorted, _ = sim_tgt.sort(dim=-1, descending=True)
            num = int(N / 4)
            src = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_src, sim_src)
            tgt = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_tgt, sim_tgt)
            if self.loss_mode == 'l1':
                loss = self.criterion((N / num) * src, (N / num) * tgt)
            elif self.loss_mode == 'cos':
                sim_pos = F.cosine_similarity(src, tgt, dim=-1)
                loss = self.criterion(torch.ones_like(sim_pos), sim_pos)
            else:
                raise NotImplementedError('loss mode [%s] is not implemented' % self.loss_mode)
        return loss
def loss(self, f_src, f_tgt, f_other=None, layer=0):
sim_src, sim_tgt, sim_other = self.cal_sim(f_src, f_tgt, f_other, layer)
loss = self.compare_sim(sim_src, sim_tgt, sim_other)
return loss
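
# Usage sketch (hypothetical feature tensors): the 1x1 conv head is created
# lazily on the first call to cal_sim, so update_init_() should be called once
# afterwards to stop it from being re-created on every forward pass.
#
#   scl = SpatialCorrelativeLoss(loss_mode='cos', patch_nums=256, patch_size=32)
#   loss = scl.loss(feat_src, feat_tgt, layer=0)
#   scl.update_init_()
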
class Normalization(nn.Module):
    def __init__(self, device):
        super(Normalization, self).__init__()
        # Register the statistics as buffers so they follow the module through
        # .to(device) and appear in its state_dict.
        mean = torch.tensor([0.485, 0.456, 0.406], device=device)
        std = torch.tensor([0.229, 0.224, 0.225], device=device)
        self.register_buffer('mean', mean.view(-1, 1, 1))
        self.register_buffer('std', std.view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std
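
# These are the standard ImageNet channel statistics used by the torchvision
# VGG models, so Normalization is meant to be applied to [0, 1] images before
# they are fed into VGG16 below.
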
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
net.to(gpu_ids[0])
if initialize_weights:
init_weights(net, init_type, init_gain=init_gain, debug=debug)
return net
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if debug:
print(classname)
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
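
# Usage sketch: init_net moves a freshly built module to the first GPU in
# gpu_ids (if any) and applies init_weights to its Conv/Linear/BatchNorm layers.
#
#   net = init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[0])
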
class VGG16(nn.Module):
def __init__(self):
super(VGG16, self).__init__()
features = vgg16(weights=VGG16_Weights.DEFAULT).features
self.relu1_1 = torch.nn.Sequential()
self.relu1_2 = torch.nn.Sequential()
self.relu2_1 = torch.nn.Sequential()
self.relu2_2 = torch.nn.Sequential()
self.relu3_1 = torch.nn.Sequential()
self.relu3_2 = torch.nn.Sequential()
self.relu3_3 = torch.nn.Sequential()
self.relu4_1 = torch.nn.Sequential()
self.relu4_2 = torch.nn.Sequential()
self.relu4_3 = torch.nn.Sequential()
self.relu5_1 = torch.nn.Sequential()
self.relu5_2 = torch.nn.Sequential()
self.relu5_3 = torch.nn.Sequential()
for x in range(2):
self.relu1_1.add_module(str(x), features[x])
for x in range(2, 4):
self.relu1_2.add_module(str(x), features[x])
for x in range(4, 7):
self.relu2_1.add_module(str(x), features[x])
for x in range(7, 9):
self.relu2_2.add_module(str(x), features[x])
for x in range(9, 12):
self.relu3_1.add_module(str(x), features[x])
for x in range(12, 14):
self.relu3_2.add_module(str(x), features[x])
for x in range(14, 16):
self.relu3_3.add_module(str(x), features[x])
        # range(16, 19) covers pool3, conv4_1, and its ReLU, so 'relu4_1' is a
        # post-activation feature like every other slice.
        for x in range(16, 19):
            self.relu4_1.add_module(str(x), features[x])
        for x in range(19, 21):
            self.relu4_2.add_module(str(x), features[x])
for x in range(21, 23):
self.relu4_3.add_module(str(x), features[x])
for x in range(23, 26):
self.relu5_1.add_module(str(x), features[x])
for x in range(26, 28):
self.relu5_2.add_module(str(x), features[x])
for x in range(28, 30):
self.relu5_3.add_module(str(x), features[x])
def forward(self, x, layers=None, encode_only=False, resize=False):
relu1_1 = self.relu1_1(x)
relu1_2 = self.relu1_2(relu1_1)
relu2_1 = self.relu2_1(relu1_2)
relu2_2 = self.relu2_2(relu2_1)
relu3_1 = self.relu3_1(relu2_2)
relu3_2 = self.relu3_2(relu3_1)
relu3_3 = self.relu3_3(relu3_2)
relu4_1 = self.relu4_1(relu3_3)
relu4_2 = self.relu4_2(relu4_1)
relu4_3 = self.relu4_3(relu4_2)
relu5_1 = self.relu5_1(relu4_3)
relu5_2 = self.relu5_2(relu5_1)
relu5_3 = self.relu5_3(relu5_2)
out = {
'relu1_1': relu1_1,
'relu1_2': relu1_2,
'relu2_1': relu2_1,
'relu2_2': relu2_2,
'relu3_1': relu3_1,
'relu3_2': relu3_2,
'relu3_3': relu3_3,
'relu4_1': relu4_1,
'relu4_2': relu4_2,
'relu4_3': relu4_3,
'relu5_1': relu5_1,
'relu5_2': relu5_2,
'relu5_3': relu5_3,
}
        if encode_only:
            # Guard against layers=None (the default) as well as an empty list.
            if layers:
                feats = []
                for layer, key in enumerate(out):
                    if layer in layers:
                        feats.append(out[key])
                return feats
            else:
                return out['relu3_1']
        return out
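
# Usage sketch (hypothetical layer indices): with encode_only=True, `layers`
# selects features by their position in the relu1_1 ... relu5_3 ordering above.
#
#   vgg = VGG16().eval()
#   feats = vgg(img, layers=[0, 4, 8], encode_only=True)
#   all_feats = vgg(img)   # dict of all thirteen activations
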