File size: 6,958 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.device import get_device
from mmaction.registry import MODELS
from .base import BaseWeightedLoss
@MODELS.register_module()
class HVULoss(BaseWeightedLoss):
    """Calculate the BCELoss for HVU.

    Args:
        categories (tuple[str]): Names of tag categories, tags are organized
            in this order. Default: ('action', 'attribute', 'concept',
            'event', 'object', 'scene').
        category_nums (tuple[int]): Number of tags for each category.
            Default: (739, 117, 291, 69, 1678, 248).
        category_loss_weights (tuple[float]): Loss weights of categories, it
            applies only if `loss_type == 'individual'`. The loss weights
            will be normalized so that the sum equals to 1, so that you can
            give any non-negative number as loss weight.
            Default: (1, 1, 1, 1, 1, 1).
        loss_type (str): The loss type we calculate, we can either calculate
            the BCELoss for all tags, or calculate the BCELoss for tags in
            each category. Choices are 'individual' or 'all'.
            Default: 'all'.
        with_mask (bool): Since some tag categories are missing for some
            video clips. If `with_mask == True`, we will not calculate loss
            for these missing categories. Otherwise, these missing
            categories are treated as negative samples. Default: False.
        reduction (str): Reduction way. Choices are 'mean' or 'sum'.
            Default: 'mean'.
        loss_weight (float): The loss weight, forwarded to the base class.
            Default: 1.0.
    """

    def __init__(self,
                 categories=('action', 'attribute', 'concept', 'event',
                             'object', 'scene'),
                 category_nums=(739, 117, 291, 69, 1678, 248),
                 category_loss_weights=(1, 1, 1, 1, 1, 1),
                 loss_type='all',
                 with_mask=False,
                 reduction='mean',
                 loss_weight=1.0):
        super().__init__(loss_weight)
        self.categories = categories
        self.category_nums = category_nums
        self.category_loss_weights = category_loss_weights
        assert len(self.category_nums) == len(self.category_loss_weights)
        for category_loss_weight in self.category_loss_weights:
            assert category_loss_weight >= 0
        self.loss_type = loss_type
        self.with_mask = with_mask
        self.reduction = reduction
        # Offset of the first tag of every category inside the concatenated
        # score / label vector (running sum of the preceding category sizes).
        self.category_startidx = [0]
        for num in self.category_nums[:-1]:
            self.category_startidx.append(self.category_startidx[-1] + num)
        assert self.loss_type in ['individual', 'all']
        assert self.reduction in ['mean', 'sum']

    def _forward(self, cls_score, label, mask, category_mask):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            mask (torch.Tensor): The mask of tags. 0 indicates that the
                category of this tag is missing in the label of the video.
                Only used when `loss_type == 'all'` and `with_mask` is set.
            category_mask (torch.Tensor): The category mask. For each sample,
                it's a tensor with length `len(self.categories)`, denotes
                that if the category is labeled for this video. Only used
                when `loss_type == 'individual'` and `with_mask` is set.

        Returns:
            dict[str, torch.Tensor]: Always contains `loss_cls`. When
            `loss_type == 'individual'` it additionally contains, for every
            category, `{name}_LOSS` (the per-category loss) and
            `{name}_LOSS_weight` (its normalized weight).

        Raises:
            ValueError: If `self.loss_type` is neither 'all' nor
                'individual'.
        """
        if self.loss_type == 'all':
            loss_cls = F.binary_cross_entropy_with_logits(
                cls_score, label, reduction='none')
            if self.with_mask:
                w_loss_cls = torch.sum(mask * loss_cls, dim=1)
                if self.reduction == 'mean':
                    valid_count = torch.sum(mask, dim=1)
                    # A sample whose mask row is all zero has a zero
                    # numerator as well; divide by 1 instead of 0 so that it
                    # contributes 0 rather than turning the loss into NaN.
                    valid_count = torch.where(
                        valid_count == 0, torch.ones_like(valid_count),
                        valid_count)
                    w_loss_cls = w_loss_cls / valid_count
                return dict(loss_cls=torch.mean(w_loss_cls))

            if self.reduction == 'sum':
                loss_cls = torch.sum(loss_cls, dim=-1)
            return dict(loss_cls=torch.mean(loss_cls))

        if self.loss_type == 'individual':
            losses = {}
            loss_weights = {}
            for idx, (name, num, start_idx) in enumerate(
                    zip(self.categories, self.category_nums,
                        self.category_startidx)):
                # We name the loss of each category as 'LOSS', since we only
                # want to monitor them, not backward them. We will also
                # provide the loss used for backward in the losses dictionary
                key = f'{name}_LOSS'
                category_score = cls_score[:, start_idx:start_idx + num]
                category_label = label[:, start_idx:start_idx + num]
                category_loss = F.binary_cross_entropy_with_logits(
                    category_score, category_label, reduction='none')
                if self.reduction == 'mean':
                    category_loss = torch.mean(category_loss, dim=1)
                elif self.reduction == 'sum':
                    category_loss = torch.sum(category_loss, dim=1)

                if self.with_mask:
                    category_mask_i = category_mask[:, idx].reshape(-1)
                    mask_total = torch.sum(category_mask_i)
                    # there should be at least one sample which contains tags
                    # in this category
                    if mask_total < 0.5:
                        # Create the placeholder on the same device as the
                        # scores so it can be combined with the other losses.
                        losses[key] = torch.tensor(
                            .0, device=cls_score.device)
                        loss_weights[key] = .0
                        continue
                    category_loss = torch.sum(
                        category_loss * category_mask_i) / mask_total
                else:
                    category_loss = torch.mean(category_loss)
                losses[key] = category_loss
                loss_weights[key] = self.category_loss_weights[idx]

            loss_weight_sum = sum(loss_weights.values())
            if loss_weight_sum > 0:
                loss_weights = {
                    k: v / loss_weight_sum
                    for k, v in loss_weights.items()
                }
                loss_cls = sum(losses[k] * loss_weights[k] for k in losses)
            else:
                # Every category is either masked out or weighted 0, so
                # normalizing would divide by zero. Report zero weights and
                # a zero loss that is still derived from `cls_score`.
                loss_weights = dict.fromkeys(loss_weights, .0)
                loss_cls = cls_score.sum() * 0
            losses['loss_cls'] = loss_cls
            # We also trace the loss weights
            losses.update({
                k + '_weight': torch.tensor(v, device=losses[k].device)
                for k, v in loss_weights.items()
            })
            # Note that the loss weights are just for reference.
            return losses

        raise ValueError("loss_type should be 'all' or 'individual', "
                         f'but got {self.loss_type}')
|