# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


@MODELS.register_module()
class CrossEntropyLoss(BaseWeightedLoss):
"""Cross Entropy Loss.
Support two kinds of labels and their corresponding loss type. It's worth
mentioning that loss type will be detected by the shape of ``cls_score``
and ``label``.
1) Hard label: This label is an integer array and all of the elements are
in the range [0, num_classes - 1]. This label's shape should be
``cls_score``'s shape with the `num_classes` dimension removed.
2) Soft label(probability distribution over classes): This label is a
probability distribution and all of the elements are in the range
[0, 1]. This label's shape must be the same as ``cls_score``. For now,
only 2-dim soft label is supported.
Args:
loss_weight (float): Factor scalar multiplied on the loss.
Defaults to 1.0.
class_weight (list[float] | None): Loss weight for each class. If set
as None, use the same weight 1 for all classes. Only applies
to CrossEntropyLoss and BCELossWithLogits (should not be set when
using other losses). Defaults to None.
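
    Examples:
        A minimal usage sketch (tensor shapes and values below are
        illustrative, not taken from any real config):

        >>> import torch
        >>> loss = CrossEntropyLoss(loss_weight=1.0)
        >>> cls_score = torch.randn(4, 5)  # (batch_size, num_classes) logits
        >>> hard_label = torch.tensor([0, 2, 4, 1])
        >>> loss(cls_score, hard_label).shape  # scalar loss
        torch.Size([])
        >>> soft_label = torch.softmax(torch.randn(4, 5), dim=1)
        >>> loss(cls_score, soft_label).shape  # soft-label branch
        torch.Size([])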
"""

    def __init__(self,
loss_weight: float = 1.0,
class_weight: Optional[List[float]] = None) -> None:
super().__init__(loss_weight=loss_weight)
self.class_weight = None
if class_weight is not None:
self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate
                CrossEntropy loss.

        Returns:
            torch.Tensor: The returned CrossEntropy loss.
        """
if cls_score.size() == label.size():
# calculate loss for soft label
assert cls_score.dim() == 2, 'Only support 2-dim soft label'
            assert len(kwargs) == 0, \
                ('For now, no extra args are supported for soft label, '
                 f'but got {kwargs}')

            lsm = F.log_softmax(cls_score, dim=1)
            if self.class_weight is not None:
                # move the cached class weights to the score's device
                self.class_weight = self.class_weight.to(cls_score.device)
                lsm = lsm * self.class_weight.unsqueeze(0)
loss_cls = -(label * lsm).sum(1)
# default reduction 'mean'
if self.class_weight is not None:
# Use weighted average as pytorch CrossEntropyLoss does.
# For more information, please visit https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html # noqa
loss_cls = loss_cls.sum() / torch.sum(
self.class_weight.unsqueeze(0) * label)
else:
loss_cls = loss_cls.mean()
else:
# calculate loss for hard label
if self.class_weight is not None:
assert 'weight' not in kwargs, \
"The key 'weight' already exists."
kwargs['weight'] = self.class_weight.to(cls_score.device)
loss_cls = F.cross_entropy(cls_score, label, **kwargs)
return loss_cls


@MODELS.register_module()
class BCELossWithLogits(BaseWeightedLoss):
"""Binary Cross Entropy Loss with logits.
Args:
loss_weight (float): Factor scalar multiplied on the loss.
Defaults to 1.0.
class_weight (list[float] | None): Loss weight for each class. If set
as None, use the same weight 1 for all classes. Only applies
to CrossEntropyLoss and BCELossWithLogits (should not be set when
using other losses). Defaults to None.
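
    Examples:
        A minimal multi-label usage sketch (shapes and values are
        illustrative only):

        >>> import torch
        >>> loss = BCELossWithLogits()
        >>> cls_score = torch.randn(4, 5)  # per-class logits
        >>> label = torch.randint(0, 2, (4, 5)).float()  # multi-hot targets
        >>> loss(cls_score, label).shape  # scalar loss (mean reduction)
        torch.Size([])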
"""

    def __init__(self,
loss_weight: float = 1.0,
class_weight: Optional[List[float]] = None) -> None:
super().__init__(loss_weight=loss_weight)
self.class_weight = None
if class_weight is not None:
self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate
                bce loss with logits.

        Returns:
            torch.Tensor: The returned bce loss with logits.
        """
if self.class_weight is not None:
assert 'weight' not in kwargs, "The key 'weight' already exists."
kwargs['weight'] = self.class_weight.to(cls_score.device)
loss_cls = F.binary_cross_entropy_with_logits(cls_score, label,
**kwargs)
return loss_cls


@MODELS.register_module()
class CBFocalLoss(BaseWeightedLoss):
"""Class Balanced Focal Loss. Adapted from https://github.com/abhinanda-
punnakkal/BABEL/. This loss is used in the skeleton-based action
recognition baseline for BABEL.
Args:
loss_weight (float): Factor scalar multiplied on the loss.
Defaults to 1.0.
        samples_per_cls (list[int], optional): The number of samples per
            class, used to compute the class-balanced weights. Defaults to
            None.
beta (float): Hyperparameter that controls the per class loss weight.
Defaults to 0.9999.
gamma (float): Hyperparameter of the focal loss. Defaults to 2.0.
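
    Examples:
        A minimal usage sketch; the per-class sample counts below are
        hypothetical:

        >>> import torch
        >>> loss = CBFocalLoss(samples_per_cls=[100, 20, 5])
        >>> cls_score = torch.randn(4, 3)  # (batch_size, num_classes)
        >>> label = torch.tensor([0, 1, 2, 0])  # hard labels
        >>> loss(cls_score, label).shape  # scalar loss
        torch.Size([])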
"""

    def __init__(self,
                 loss_weight: float = 1.0,
                 samples_per_cls: Optional[List[int]] = None,
                 beta: float = 0.9999,
                 gamma: float = 2.) -> None:
        super().__init__(loss_weight=loss_weight)
        if samples_per_cls is None:
            samples_per_cls = []
        self.samples_per_cls = samples_per_cls
        self.beta = beta
        self.gamma = gamma

        # Class-balanced weights from the effective number of samples
        # (Cui et al., CVPR 2019): E_n = (1 - beta^n) / (1 - beta).
        effective_num = 1.0 - np.power(beta, samples_per_cls)
        weights = (1.0 - beta) / np.array(effective_num)
        # rescale so that the weights sum to the number of classes
        weights = weights / np.sum(weights) * len(weights)
        self.weights = weights
        self.num_classes = len(weights)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate
                the focal loss.

        Returns:
            torch.Tensor: The returned Class Balanced Focal loss.
        """
        weights = torch.tensor(self.weights).float().to(cls_score.device)
        label_one_hot = F.one_hot(label, self.num_classes).float()
        # Select the class-balanced weight of each sample's ground-truth
        # class, then broadcast it over all class dimensions so that every
        # binary term of a sample is scaled by the same weight.
        weights = weights.unsqueeze(0)
        weights = weights.repeat(label_one_hot.shape[0], 1) * label_one_hot
        weights = weights.sum(1)
        weights = weights.unsqueeze(1)
        weights = weights.repeat(1, self.num_classes)

        BCELoss = F.binary_cross_entropy_with_logits(
            input=cls_score, target=label_one_hot, reduction='none')

        # Focal modulator (1 - p_t) ** gamma, computed in log space for
        # numerical stability: log(1 + exp(-x)) == softplus(-x), which
        # avoids overflow for large negative logits.
        modulator = 1.0
        if self.gamma:
            modulator = torch.exp(-self.gamma * label_one_hot * cls_score -
                                  self.gamma * F.softplus(-cls_score))

        loss = modulator * BCELoss
        weighted_loss = weights * loss
        focal_loss = torch.sum(weighted_loss)
        # normalize by the number of positive entries, i.e. the batch size
        focal_loss /= torch.sum(label_one_hot)
        return focal_loss