dxcanh committed on
Commit
adeeb5c
Β·
verified Β·
1 Parent(s): 0c0af96

Upload 3 files

Browse files
basicsr/losses/basic_loss.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
6
+ from basicsr.utils.registry import LOSS_REGISTRY
7
+ from .loss_util import weighted_loss
8
+
9
+ _reduction_modes = ['none', 'mean', 'sum']
10
+
11
+
12
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error; reduction is handled by the decorator."""
    return (pred - target).abs()
15
+
16
+
17
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error; reduction is handled by the decorator."""
    return (pred - target) ** 2
20
+
21
+
22
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier penalty sqrt(diff^2 + eps); no reduction here."""
    diff = pred - target
    return torch.sqrt(diff * diff + eps)
25
+
26
+
27
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        # Validate against the shared module constant instead of a duplicated
        # literal list, so the supported modes are defined in one place.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
53
+
54
+
55
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        # Validate against the shared module constant instead of a duplicated
        # literal list, so the supported modes are defined in one place.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
81
+
82
+
83
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero. Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        # Validate against the shared module constant instead of a duplicated
        # literal list, so the supported modes are defined in one place.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
115
+
116
+
117
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation loss built on top of L1Loss.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        # TV loss has no meaningful 'none' reduction, so restrict the choices.
        if reduction not in ['mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        """Sum of L1 differences between vertically and horizontally adjacent pixels."""
        if weight is None:
            w_vert, w_horz = None, None
        else:
            w_vert = weight[:, :, :-1, :]
            w_horz = weight[:, :, :, :-1]

        vert_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=w_vert)
        horz_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=w_horz)

        return vert_diff + horz_diff
144
+
145
+
146
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        # VGG extractor returning only the layers named in layer_weights
        # (presumably as a {layer_name: feature tensor} dict, as consumed in
        # forward() below — behavior defined in basicsr.archs.vgg_arch).
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            # Frobenius norm is computed inline in forward(); no module needed.
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features; gt is detached since it is a fixed target
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss: weighted per-layer distance in feature space
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss: same per-layer distance, but on Gram matrices
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        # Either element may be None when its weight is <= 0.
        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c).
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        # Normalize by the element count so the Gram magnitude is
        # resolution-independent.
        gram = features.bmm(features_t) / (c * h * w)
        return gram
basicsr/losses/gan_loss.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import autograd as autograd
4
+ from torch import nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.utils.registry import LOSS_REGISTRY
8
+
9
+
10
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Module-based criteria are constructed from this table; the two
        # wgan variants dispatch to bound methods instead.
        module_criteria = {'vanilla': nn.BCEWithLogitsLoss, 'lsgan': nn.MSELoss, 'hinge': nn.ReLU}
        if self.gan_type in module_criteria:
            self.loss = module_criteria[self.gan_type]()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss (negated critic mean for real targets).
        """
        if target:
            return -input.mean()
        return input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with softplus, a smooth approximation of ReLU.

        In StyleGAN2 this is the logistic loss for the discriminator and the
        non-saturating loss for the generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan softplus loss.
        """
        if target:
            return F.softplus(-input).mean()
        return F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        # wgan-style criteria consume the boolean flag directly.
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        label_val = self.real_label_val if target_is_real else self.fake_label_val
        return input.new_ones(input.size()) * label_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
113
+
114
+
115
@LOSS_REGISTRY.register()
class MultiScaleGANLoss(GANLoss):
    """GANLoss variant that accepts a list of predictions (one per scale)."""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """
        The input is a list of tensors, or a list of (a list of tensors)
        """
        if not isinstance(input, list):
            # Single prediction: plain GANLoss behavior.
            return super().forward(input, target_is_real, is_disc)

        total = 0
        for scale_pred in input:
            if isinstance(scale_pred, list):
                # With multiscale feature matching only the last layer
                # carries the GAN prediction.
                scale_pred = scale_pred[-1]
            # .mean() is a no-op on 0-dim tensors, so this is always safe.
            total += super().forward(scale_pred, target_is_real, is_disc).mean()
        return total / len(input)
141
+
142
+
143
def r1_penalty(real_pred, real_img):
    """R1 regularization for the discriminator.

    Penalizes the gradient of the discriminator on real data alone: when the
    generator matches the true data distribution and the discriminator is 0
    on the data manifold, this keeps the discriminator from building a
    non-zero gradient orthogonal to the manifold without paying a loss.

    Reference: Eq. 9 in Which training methods for GANs do actually converge.
    """
    (grad_real, ) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    # Squared gradient magnitude per sample, averaged over the batch.
    per_sample = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()
157
+
158
+
159
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """StyleGAN2 path-length regularization for the generator.

    Returns the penalty, the detached mean path length of this batch, and the
    updated (detached) running mean.
    """
    # Scale the noise so the expected gradient magnitude is independent of
    # the image resolution.
    img_area = fake_img.shape[2] * fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(img_area)
    (grad, ) = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()

    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
169
+
170
+
171
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One interpolation coefficient per sample, on real_data's device/dtype.
    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # `autograd.Variable` is deprecated since PyTorch 0.4; detaching and
    # marking the tensor as requiring grad is the modern equivalent (it also
    # makes `interpolates` a graph leaf, as Variable did).
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    # WGAN-GP: drive the per-pixel channel-norm of the gradient toward 1.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.

    Raises:
        ValueError: If `reduction` is not a supported mode.
    """
    # Dispatch on the reduction string directly instead of going through
    # `F._Reduction.get_enum`, which is a private torch API that may change
    # without notice. The ValueError message matches the torch one.
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f'{reduction} is not a valid value for reduction')
24
+
25
+
26
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    if weight is not None:
        # Weight must match the loss rank and either broadcast over the
        # channel dim (size 1) or match it exactly.
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    if weight is None or reduction == 'sum':
        # Unweighted losses and all 'sum' reductions use plain reduction.
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # Normalize by the total weight instead of the element count, so
        # zero-weight regions do not dilute the mean.
        if weight.size(1) > 1:
            denom = weight.sum()
        else:
            denom = weight.sum() * loss.size(1)
        return loss.sum() / denom

    # reduction == 'none' with a weight: return the weighted map untouched.
    return loss
56
+
57
+
58
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>> return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Compute the element-wise loss, then delegate weighting/reduction.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
97
+
98
+
99
def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
        ksize (Int): size of the local window.

    Returns:
        Tensor: weight for each pixel to be discriminated as an artifact pixel
    """
    pad = (ksize - 1) // 2
    padded = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')

    # Collect every ksize x ksize neighbourhood, then score each pixel by the
    # (unbiased) variance of its local window.
    windows = padded.unfold(2, ksize, 1).unfold(3, ksize, 1)
    local_var = torch.var(windows, dim=(-1, -2), unbiased=True, keepdim=True)
    return local_var.squeeze(-1).squeeze(-1)
119
+
120
+
121
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)

    Args:
        img_gt (Tensor): ground truth images.
        img_output (Tensor): output images given by the optimizing model.
        img_ema (Tensor): output images given by the ema model.
        ksize (Int): size of the local window.

    Returns:
        overall_weight: weight for each pixel to be discriminated as an artifact pixel
        (calculated based on both local and global observations).
    """
    # Channel-summed absolute residuals of the EMA model and the current model.
    residual_ema = torch.abs(img_gt - img_ema).sum(1, keepdim=True)
    residual_sr = torch.abs(img_gt - img_output).sum(1, keepdim=True)

    # Global (patch-level) statistic times the local (pixel-level) variance.
    patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
    pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    # Pixels where the current model already beats the EMA output are not
    # treated as artifacts.
    overall_weight[residual_sr < residual_ema] = 0

    return overall_weight