File size: 6,958 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.device import get_device

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


@MODELS.register_module()
class HVULoss(BaseWeightedLoss):
    """Calculate the BCELoss for HVU.

    Args:
        categories (tuple[str]): Names of tag categories, tags are organized in
            this order. Default: ('action', 'attribute', 'concept', 'event',
            'object', 'scene').
        category_nums (tuple[int]): Number of tags for each category. Default:
            (739, 117, 291, 69, 1678, 248).
        category_loss_weights (tuple[float]): Loss weights of categories. They
            apply only if `loss_type == 'individual'`. The loss weights are
            normalized so that they sum to 1, so any non-negative number can
            be given as a loss weight. Default: (1, 1, 1, 1, 1, 1).
        loss_type (str): The loss type to calculate: either the BCELoss over
            all tags ('all'), or the BCELoss of the tags in each category
            ('individual'). Choices are 'individual' or 'all'. Default: 'all'.
        with_mask (bool): Some tag categories are missing for some video
            clips. If `with_mask == True`, no loss is calculated for these
            missing categories. Otherwise, these missing categories are treated
            as negative samples. Default: False.
        reduction (str): Reduction way. Choices are 'mean' or 'sum'. Default:
            'mean'.
        loss_weight (float): The loss weight. Default: 1.0.
    """

    def __init__(self,
                 categories=('action', 'attribute', 'concept', 'event',
                             'object', 'scene'),
                 category_nums=(739, 117, 291, 69, 1678, 248),
                 category_loss_weights=(1, 1, 1, 1, 1, 1),
                 loss_type='all',
                 with_mask=False,
                 reduction='mean',
                 loss_weight=1.0):
        super().__init__(loss_weight)
        self.categories = categories
        self.category_nums = category_nums
        self.category_loss_weights = category_loss_weights
        assert len(self.category_nums) == len(self.category_loss_weights)
        for category_loss_weight in self.category_loss_weights:
            assert category_loss_weight >= 0
        self.loss_type = loss_type
        self.with_mask = with_mask
        self.reduction = reduction
        # Column offset of the first tag of each category in the concatenated
        # tag vector (exclusive prefix sum of ``category_nums``).
        self.category_startidx = [0]
        for num in self.category_nums[:-1]:
            self.category_startidx.append(self.category_startidx[-1] + num)
        assert self.loss_type in ['individual', 'all']
        assert self.reduction in ['mean', 'sum']

    def _forward(self, cls_score, label, mask, category_mask):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score (logits), passed to
                ``binary_cross_entropy_with_logits``.
            label (torch.Tensor): The ground truth label.
            mask (torch.Tensor): The mask of tags. 0 indicates that the
                category of this tag is missing in the label of the video.
            category_mask (torch.Tensor): The category mask. For each sample,
                it's a tensor with length `len(self.categories)`, denoting
                whether the category is labeled for this video.

        Returns:
            dict[str, torch.Tensor]: The binary cross-entropy losses. Always
            contains ``loss_cls``. With ``loss_type == 'individual'`` it also
            contains, per category, ``{name}_LOSS`` (the category loss, for
            monitoring) and ``{name}_LOSS_weight`` (its normalized weight).
        """
        if self.loss_type == 'all':
            loss_cls = F.binary_cross_entropy_with_logits(
                cls_score, label, reduction='none')
            if self.with_mask:
                w_loss_cls = torch.sum(mask * loss_cls, dim=1)
                if self.reduction == 'mean':
                    num_valid = torch.sum(mask, dim=1)
                    # A sample whose categories are all missing has no valid
                    # tag; dividing by its zero count would yield NaN and
                    # poison the whole batch. With a 0/1 mask its masked loss
                    # sum is 0 too, so dividing by 1 makes it contribute 0.
                    # The denominator (not the quotient) is patched so no NaN
                    # reaches autograd.
                    num_valid = torch.where(num_valid > 0, num_valid,
                                            torch.ones_like(num_valid))
                    w_loss_cls = w_loss_cls / num_valid
                return dict(loss_cls=torch.mean(w_loss_cls))

            if self.reduction == 'sum':
                loss_cls = torch.sum(loss_cls, dim=-1)
            return dict(loss_cls=torch.mean(loss_cls))

        if self.loss_type == 'individual':
            losses = {}
            loss_weights = {}
            for idx, (name, num, start_idx) in enumerate(
                    zip(self.categories, self.category_nums,
                        self.category_startidx)):
                category_score = cls_score[:, start_idx:start_idx + num]
                category_label = label[:, start_idx:start_idx + num]
                category_loss = F.binary_cross_entropy_with_logits(
                    category_score, category_label, reduction='none')
                if self.reduction == 'mean':
                    category_loss = torch.mean(category_loss, dim=1)
                elif self.reduction == 'sum':
                    category_loss = torch.sum(category_loss, dim=1)

                if self.with_mask:
                    category_mask_i = category_mask[:, idx].reshape(-1)
                    # there should be at least one sample which contains tags
                    # in this category
                    if torch.sum(category_mask_i) < 0.5:
                        losses[f'{name}_LOSS'] = torch.tensor(
                            .0, device=get_device())
                        loss_weights[f'{name}_LOSS'] = .0
                        continue
                    category_loss = torch.sum(category_loss * category_mask_i)
                    category_loss = category_loss / torch.sum(category_mask_i)
                else:
                    category_loss = torch.mean(category_loss)
                # We name the loss of each category as 'LOSS', since we only
                # want to monitor them, not backward them. We will also provide
                # the loss used for backward in the losses dictionary
                losses[f'{name}_LOSS'] = category_loss
                loss_weights[f'{name}_LOSS'] = self.category_loss_weights[idx]

            loss_weight_sum = sum(loss_weights.values())
            if loss_weight_sum > 0:
                loss_weights = {
                    k: v / loss_weight_sum
                    for k, v in loss_weights.items()
                }
                loss_cls = sum(losses[k] * loss_weights[k] for k in losses)
            else:
                # No category contributes (every category is missing from the
                # batch, or all configured weights are 0). Normalizing would
                # divide by zero; emit a zero loss that is still attached to
                # the graph so that backward does not fail.
                loss_weights = {k: .0 for k in loss_weights}
                loss_cls = cls_score.sum() * 0.
            losses['loss_cls'] = loss_cls
            # We also trace the loss weights
            losses.update({
                k + '_weight': torch.tensor(v).to(losses[k].device)
                for k, v in loss_weights.items()
            })
            # Note that the loss weights are just for reference.
            return losses

        raise ValueError("loss_type should be 'all' or 'individual', "
                         f'but got {self.loss_type}')