| | |
| |
|
| | import os |
| | import torch |
| | from torch import nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class PixelLoss(nn.Module):
    """Pixel-wise L1 (mean absolute error) loss between two tensors."""

    def __init__(self) -> None:
        super().__init__()
        # nn.L1Loss owns no parameters or buffers, so calling `.cuda()` on it
        # moves nothing; it simply runs on whatever device the inputs use.
        self.criterion = nn.L1Loss()

    def forward(self, gen_hr, org_hr, batch_idx=None):
        """Return the mean absolute error between ``gen_hr`` and ``org_hr``.

        Args:
            gen_hr (torch.Tensor): generated / predicted tensor.
            org_hr (torch.Tensor): reference tensor, expected to have the
                same shape as ``gen_hr``.
            batch_idx: unused; accepted so every loss in this module can be
                called as ``loss(prediction, target, batch_idx)``.

        Returns:
            torch.Tensor: scalar mean L1 loss.
        """
        return self.criterion(gen_hr, org_hr)
| | |
| |
|
class L1_Charbonnier_loss(nn.Module):
    """L1 Charbonnier loss: ``mean(sqrt((X - Y)**2 + eps))``.

    A smooth variant of the L1 loss; ``eps`` keeps the square root (and its
    gradient) well defined where ``X == Y``.
    """

    def __init__(self, eps=1e-6):
        """
        Args:
            eps (float): smoothing constant added under the square root.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X, Y, batch_idx=None):
        """Return the mean Charbonnier error between ``X`` and ``Y``.

        Args:
            X (torch.Tensor): prediction.
            Y (torch.Tensor): target, broadcast-compatible with ``X``.
            batch_idx: unused; kept for a uniform loss call signature.

        Returns:
            torch.Tensor: scalar loss.
        """
        diff = X - Y
        error = torch.sqrt(diff * diff + self.eps)
        return torch.mean(error)
| | |
| |
|
| |
|
| | """ |
| | Created on Thu Dec 3 00:28:15 2020 |
| | @author: Yunpeng Li, Tianjin University |
| | """ |
class MS_SSIM_L1_LOSS(nn.Module):
    """Mix of an MS-SSIM loss and a Gaussian-weighted L1 loss for 3-channel images.

    Scales are emulated with Gaussian windows of increasing sigma applied at
    full resolution (no down-sampling). The luminance term uses only the
    largest sigma; the contrast-structure term is the product over every
    (channel, sigma) pair. The per-pixel loss is::

        compensation * (alpha * (1 - l_M * prod(cs))
                        + (1 - alpha) * G_sigma_max(|x - y|) / data_range)

    and ``forward`` returns its mean over batch and spatial positions.
    """

    def __init__(self, alpha,
                 gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range=1.0,
                 K=(0.01, 0.4),
                 compensation=1.0,
                 cuda_dev=0,
                 log_tensorboard=True):
        """
        Args:
            alpha (float): weight of the MS-SSIM term; ``1 - alpha`` weights
                the Gaussian-weighted L1 term.
            gaussian_sigmas (sequence of float): one Gaussian sigma per scale.
                The last entry sets the kernel size, the padding and the L1
                window, so it should be the largest.
            data_range (float): dynamic range of the inputs (e.g. 1.0 for
                inputs scaled to [0, 1]); scales C1/C2 and normalizes the L1
                term.
            K (tuple of float): SSIM stabilizing constants ``(K1, K2)``.
            compensation (float): global multiplier applied to the mixed loss.
            cuda_dev (int or None): CUDA device the masks are initially placed
                on. ``None``, or no CUDA available, keeps them on the CPU;
                ``forward`` moves them to the input's device/dtype anyway.
            log_tensorboard (bool): if True (default), create a
                ``SummaryWriter`` and log both loss components on every
                ``forward`` call.
        """
        super().__init__()
        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        # NOTE(review): the SSIM paper's usual K2 is 0.03; the 0.4 default is
        # kept unchanged -- confirm it is intentional.
        self.C2 = (K[1] * data_range) ** 2
        self.alpha = alpha
        self.compensation = compensation

        # Kernel spans +/- 2*sigma_max; padding by half of it keeps H and W.
        filter_size = int(4 * gaussian_sigmas[-1] + 1)
        self.pad = int(2 * gaussian_sigmas[-1])

        kernels = [self._fspecial_gauss_2d(filter_size, sigma)
                   for sigma in gaussian_sigmas]
        # F.conv2d(groups=3) feeds input channel c only to output channels
        # [c*S, (c+1)*S) (S = number of scales), so the weights must be
        # channel-major: index c*S + s holds the kernel for sigma_s.
        g_masks = torch.stack(kernels * 3).unsqueeze(1)          # [3*S, 1, k, k]
        # Largest-sigma kernel once per channel, for the weighted L1 term.
        l1_masks = torch.stack([kernels[-1]] * 3).unsqueeze(1)   # [3, 1, k, k]

        if cuda_dev is not None and torch.cuda.is_available():
            g_masks = g_masks.cuda(cuda_dev)
            l1_masks = l1_masks.cuda(cuda_dev)
        # Non-persistent buffers follow module.to()/.cuda() but are excluded
        # from state_dict, so the module's state_dict keys do not change.
        self.register_buffer('g_masks', g_masks, persistent=False)
        self.register_buffer('l1_masks', l1_masks, persistent=False)

        self.writer = None
        if log_tensorboard:
            # Imported lazily so tensorboard is only needed when logging.
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter()

    def _fspecial_gauss_1d(self, size, sigma):
        """Create a normalized 1-D Gaussian kernel.

        Args:
            size (int): number of taps.
            sigma (float): standard deviation of the Gaussian.

        Returns:
            torch.Tensor: float32 kernel of shape ``(size,)`` summing to 1.
        """
        coords = torch.arange(size, dtype=torch.float)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create a normalized, separable 2-D Gaussian kernel.

        Args:
            size (int): side length of the kernel.
            sigma (float): standard deviation of the Gaussian.

        Returns:
            torch.Tensor: kernel of shape ``(size, size)`` summing to 1.
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y, batch_idx=None):
        """Compute the mixed MS-SSIM + Gaussian-L1 loss.

        Args:
            x (torch.Tensor): prediction of shape [B, 3, H, W]; the grouped
                convolutions require exactly 3 channels.
            y (torch.Tensor): target, same shape as ``x``.
            batch_idx (int, optional): global step used for TensorBoard
                logging.

        Returns:
            torch.Tensor: scalar loss (mean over batch and pixels).
        """
        # Match the input's device/dtype (a no-op when they already match).
        g_masks = self.g_masks.to(device=x.device, dtype=x.dtype)
        l1_masks = self.l1_masks.to(device=x.device, dtype=x.dtype)
        n_scales = g_masks.shape[0] // 3

        def blur(t, masks):
            return F.conv2d(t, masks, groups=3, padding=self.pad)

        mux = blur(x, g_masks)                      # [B, 3*S, H, W]
        muy = blur(y, g_masks)
        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = blur(x * x, g_masks) - mux2
        sigmay2 = blur(y * y, g_masks) - muy2
        sigmaxy = blur(x * y, g_masks) - muxy

        # Luminance and contrast-structure maps for every (channel, sigma).
        l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        b, _, h, w = l.shape
        # Luminance only at the largest sigma (last scale of each channel).
        lM = l.reshape(b, 3, n_scales, h, w)[:, :, -1].prod(dim=1)   # [B, H, W]
        PIcs = cs.prod(dim=1)                                        # [B, H, W]
        loss_ms_ssim = 1 - lM * PIcs

        loss_l1 = F.l1_loss(x, y, reduction='none')                  # [B, 3, H, W]
        # Largest-sigma Gaussian-weighted L1, averaged over the 3 channels.
        gaussian_l1 = blur(loss_l1, l1_masks).mean(1)                # [B, H, W]

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix
        combined_loss = loss_mix.mean()

        if self.writer is not None:
            self.writer.add_scalar('Loss/ms_ssim_loss-iteration', loss_ms_ssim.mean(), batch_idx)
            self.writer.add_scalar('Loss/l1_loss-iteration', gaussian_l1.mean(), batch_idx)

        return combined_loss
| |
|