File size: 1,307 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch.nn as nn


class BaseWeightedLoss(nn.Module, metaclass=ABCMeta):
    """Base class for loss.

    All subclasses should overwrite the ``_forward()`` method, which returns
    the normal loss without loss weights. ``forward()`` then multiplies that
    result by ``loss_weight``.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    @abstractmethod
    def _forward(self, *args, **kwargs):
        """Forward function returning the unweighted loss.

        Subclasses may return either a single loss value or a dict; in the
        dict case, ``forward()`` only weights entries whose key contains
        ``'loss'``.
        """

    def forward(self, *args, **kwargs):
        """Defines the computation performed at every call.

        Args:
            *args: The positional arguments for the corresponding
                loss.
            **kwargs: The keyword arguments for the corresponding
                loss.

        Returns:
            torch.Tensor | dict: The calculated loss multiplied by
            ``loss_weight``. If ``_forward`` returns a dict, only the
            entries whose key contains ``'loss'`` are multiplied; the other
            entries are returned unchanged.
        """
        ret = self._forward(*args, **kwargs)
        # Multiply out of place. An in-place ``*=`` would modify the tensor
        # produced by ``_forward``, which breaks autograd when that tensor is
        # saved for backward (e.g. the output of ``torch.sqrt``), mutates the
        # caller's data when ``_forward`` returns an existing tensor, and
        # fails for integer tensors combined with a float weight.
        if isinstance(ret, dict):
            for k in ret:
                # Non-loss entries (e.g. logged metrics) are left unweighted.
                if 'loss' in k:
                    ret[k] = ret[k] * self.loss_weight
        else:
            ret = ret * self.loss_weight
        return ret