PubAccount commited on
Commit
ead0a80
·
verified ·
1 Parent(s): d68b77d

Update networks/height_head.py

Browse files
Files changed (1) hide show
  1. networks/height_head.py +124 -3
networks/height_head.py CHANGED
@@ -1,10 +1,131 @@
1
  import torch
2
  from torch import nn
3
  from torch.nn import functional as F
4
-
 
 
5
  from mmcv.cnn import ConvModule
6
- from mmseg.models.losses.silog_loss import silog_loss
7
- from mmseg.models.utils import resize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def nig_nll(gamma, v, alpha, beta, y):
10
  two_beta_lambda = 2 * beta * (1 + v)
 
1
  import torch
2
  from torch import nn
3
  from torch.nn import functional as F
4
+ import warnings
5
+ from typing import Optional, Union
6
+ from torch import Tensor
7
  from mmcv.cnn import ConvModule
8
+
9
+ def reduce_loss(loss, reduction) -> torch.Tensor:
10
+ """Reduce loss as specified.
11
+
12
+ Args:
13
+ loss (Tensor): Elementwise loss tensor.
14
+ reduction (str): Options are "none", "mean" and "sum".
15
+
16
+ Return:
17
+ Tensor: Reduced loss tensor.
18
+ """
19
+ reduction_enum = F._Reduction.get_enum(reduction)
20
+ # none: 0, elementwise_mean:1, sum: 2
21
+ if reduction_enum == 0:
22
+ return loss
23
+ elif reduction_enum == 1:
24
+ return loss.mean()
25
+ elif reduction_enum == 2:
26
+ return loss.sum()
27
+
28
+ def weight_reduce_loss(loss,
29
+ weight=None,
30
+ reduction='mean',
31
+ avg_factor=None) -> torch.Tensor:
32
+ """Apply element-wise weight and reduce loss.
33
+
34
+ Args:
35
+ loss (Tensor): Element-wise loss.
36
+ weight (Tensor): Element-wise weights.
37
+ reduction (str): Same as built-in losses of PyTorch.
38
+ avg_factor (float): Average factor when computing the mean of losses.
39
+
40
+ Returns:
41
+ Tensor: Processed loss values.
42
+ """
43
+ # if weight is specified, apply element-wise weight
44
+ if weight is not None:
45
+ assert weight.dim() == loss.dim()
46
+ if weight.dim() > 1:
47
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
48
+ loss = loss * weight
49
+
50
+ # if avg_factor is not specified, just reduce the loss
51
+ if avg_factor is None:
52
+ loss = reduce_loss(loss, reduction)
53
+ else:
54
+ # if reduction is mean, then average the loss by avg_factor
55
+ if reduction == 'mean':
56
+ # Avoid causing ZeroDivisionError when avg_factor is 0.0,
57
+ # i.e., all labels of an image belong to ignore index.
58
+ eps = torch.finfo(torch.float32).eps
59
+ loss = loss.sum() / (avg_factor + eps)
60
+ # if reduction is 'none', then do nothing, otherwise raise an error
61
+ elif reduction != 'none':
62
+ raise ValueError('avg_factor can not be used with reduction="sum"')
63
+ return loss
64
+
65
+ def resize(input,
66
+ size=None,
67
+ scale_factor=None,
68
+ mode='nearest',
69
+ align_corners=None,
70
+ warning=True):
71
+ if warning:
72
+ if size is not None and align_corners:
73
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
74
+ output_h, output_w = tuple(int(x) for x in size)
75
+ if output_h > input_h or output_w > output_h:
76
+ if ((output_h > 1 and output_w > 1 and input_h > 1
77
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
78
+ and (output_w - 1) % (input_w - 1)):
79
+ warnings.warn(
80
+ f'When align_corners={align_corners}, '
81
+ 'the output would more aligned if '
82
+ f'input size {(input_h, input_w)} is `x+1` and '
83
+ f'out size {(output_h, output_w)} is `nx+1`')
84
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
85
+
86
+ def silog_loss(pred: Tensor,
87
+ target: Tensor,
88
+ weight: Optional[Tensor] = None,
89
+ eps: float = 1e-4,
90
+ reduction: Union[str, None] = 'mean',
91
+ avg_factor: Optional[int] = None) -> Tensor:
92
+ """Computes the Scale-Invariant Logarithmic (SI-Log) loss between
93
+ prediction and target.
94
+
95
+ Args:
96
+ pred (Tensor): Predicted output.
97
+ target (Tensor): Ground truth.
98
+ weight (Optional[Tensor]): Optional weight to apply on the loss.
99
+ eps (float): Epsilon value to avoid division and log(0).
100
+ reduction (Union[str, None]): Specifies the reduction to apply to the
101
+ output: 'mean', 'sum' or None.
102
+ avg_factor (Optional[int]): Optional average factor for the loss.
103
+
104
+ Returns:
105
+ Tensor: The calculated SI-Log loss.
106
+ """
107
+ pred, target = pred.flatten(1), target.flatten(1)
108
+ valid_mask = (target > eps).detach().float()
109
+
110
+ diff_log = torch.log(target.clamp(min=eps)) - torch.log(
111
+ pred.clamp(min=eps))
112
+
113
+ valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
114
+ diff_log[~valid_mask] = 0.0
115
+ valid_mask = valid_mask.float()
116
+
117
+ diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
118
+ dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
119
+ diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
120
+ dim=1).clamp(min=eps)
121
+
122
+ loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2))
123
+
124
+ if weight is not None:
125
+ weight = weight.float()
126
+
127
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
128
+ return loss
129
 
130
  def nig_nll(gamma, v, alpha, beta, y):
131
  two_beta_lambda = 2 * beta * (1 + v)