from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


@MODELS.register_module()
class CrossEntropyLoss(BaseWeightedLoss):
    """Cross Entropy Loss.

    Supports two kinds of labels and their corresponding loss types. The
    loss type is detected from the shapes of ``cls_score`` and ``label``.

    1) Hard label: The label is an integer tensor whose elements all lie in
       the range [0, num_classes - 1]. Its shape should be ``cls_score``'s
       shape with the ``num_classes`` dimension removed.
    2) Soft label (a probability distribution over classes): The label is a
       probability distribution whose elements all lie in the range [0, 1].
       Its shape must be the same as ``cls_score``. For now, only 2-dim
       soft labels are supported.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Defaults to 1.0.
        class_weight (list[float] | None): Loss weight for each class. If
            set as None, use the same weight 1 for all classes. Only applies
            to CrossEntropyLoss and BCELossWithLogits (should not be set
            when using other losses). Defaults to None.
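
    Examples:
        >>> # Illustrative sketch only; values are arbitrary, and calling
        >>> # the loss assumes ``BaseWeightedLoss`` implements ``forward``
        >>> # to scale ``_forward``'s result by ``loss_weight``.
        >>> import torch
        >>> import torch.nn.functional as F
        >>> loss = CrossEntropyLoss(class_weight=[0.5, 1.0, 2.0])
        >>> cls_score = torch.randn(4, 3)
        >>> hard_label = torch.randint(0, 3, (4, ))
        >>> out = loss(cls_score, hard_label)  # hard-label branch
        >>> soft_label = F.softmax(torch.randn(4, 3), dim=1)
        >>> out = loss(cls_score, soft_label)  # soft-label branch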
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 class_weight: Optional[List[float]] = None) -> None:
        super().__init__(loss_weight=loss_weight)
        self.class_weight = None
        if class_weight is not None:
            self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate the
                CrossEntropy loss.

        Returns:
            torch.Tensor: The returned CrossEntropy loss.
        """
        if cls_score.size() == label.size():
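            # Soft-label branch: ``label`` is a per-class probability
            # distribution with the same shape as ``cls_score``.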
            assert cls_score.dim() == 2, 'Only support 2-dim soft label'
            assert len(kwargs) == 0, \
                ('For now, no extra args are supported for soft label, '
                 f'but got {kwargs}')

            lsm = F.log_softmax(cls_score, 1)
            if self.class_weight is not None:
                self.class_weight = self.class_weight.to(cls_score.device)
                lsm = lsm * self.class_weight.unsqueeze(0)
            loss_cls = -(label * lsm).sum(1)
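            # Reduce with a weighted average, as PyTorch's weighted
            # cross-entropy does, so class weights do not inflate the
            # overall loss scale.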
            if self.class_weight is not None:
                loss_cls = loss_cls.sum() / torch.sum(
                    self.class_weight.unsqueeze(0) * label)
            else:
                loss_cls = loss_cls.mean()
        else:
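            # Hard-label branch: ``label`` holds integer class indices, so
            # the computation defers to ``F.cross_entropy``.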
            if self.class_weight is not None:
                assert 'weight' not in kwargs, \
                    "The key 'weight' already exists."
                kwargs['weight'] = self.class_weight.to(cls_score.device)
            loss_cls = F.cross_entropy(cls_score, label, **kwargs)

        return loss_cls


@MODELS.register_module()
class BCELossWithLogits(BaseWeightedLoss):
"""Binary Cross Entropy Loss with logits.
|
|
|
|
|
|
Args:
|
|
|
loss_weight (float): Factor scalar multiplied on the loss.
|
|
|
Defaults to 1.0.
|
|
|
class_weight (list[float] | None): Loss weight for each class. If set
|
|
|
as None, use the same weight 1 for all classes. Only applies
|
|
|
to CrossEntropyLoss and BCELossWithLogits (should not be set when
|
|
|
using other losses). Defaults to None.
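
    Examples:
        >>> # Illustrative multi-label sketch; values are arbitrary, and
        >>> # calling the loss assumes ``BaseWeightedLoss`` implements
        >>> # ``forward``.
        >>> import torch
        >>> loss = BCELossWithLogits()
        >>> cls_score = torch.randn(4, 5)
        >>> label = torch.randint(0, 2, (4, 5)).float()
        >>> out = loss(cls_score, label)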
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 class_weight: Optional[List[float]] = None) -> None:
        super().__init__(loss_weight=loss_weight)
        self.class_weight = None
        if class_weight is not None:
            self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate the
                BCE loss with logits.

        Returns:
            torch.Tensor: The returned BCE loss with logits.
        """
        if self.class_weight is not None:
            assert 'weight' not in kwargs, "The key 'weight' already exists."
            kwargs['weight'] = self.class_weight.to(cls_score.device)
        loss_cls = F.binary_cross_entropy_with_logits(cls_score, label,
                                                      **kwargs)
        return loss_cls


@MODELS.register_module()
class CBFocalLoss(BaseWeightedLoss):
"""Class Balanced Focal Loss. Adapted from https://github.com/abhinanda-
|
|
|
punnakkal/BABEL/. This loss is used in the skeleton-based action
|
|
|
recognition baseline for BABEL.
|
|
|
|
|
|
Args:
|
|
|
loss_weight (float): Factor scalar multiplied on the loss.
|
|
|
Defaults to 1.0.
|
|
|
samples_per_cls (list[int]): The number of samples per class.
|
|
|
Defaults to [].
|
|
|
beta (float): Hyperparameter that controls the per class loss weight.
|
|
|
Defaults to 0.9999.
|
|
|
gamma (float): Hyperparameter of the focal loss. Defaults to 2.0.
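
    Examples:
        >>> # Illustrative sketch; the sample counts are arbitrary.
        >>> # ``samples_per_cls`` must have one entry per class, and calling
        >>> # the loss assumes ``BaseWeightedLoss`` implements ``forward``.
        >>> import torch
        >>> loss = CBFocalLoss(samples_per_cls=[100, 50, 10])
        >>> cls_score = torch.randn(4, 3)
        >>> label = torch.randint(0, 3, (4, ))
        >>> out = loss(cls_score, label)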
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 samples_per_cls: List[int] = [],
                 beta: float = 0.9999,
                 gamma: float = 2.) -> None:
        super().__init__(loss_weight=loss_weight)
        self.samples_per_cls = samples_per_cls
        self.beta = beta
        self.gamma = gamma
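        # Class-balanced re-weighting (Cui et al., "Class-Balanced Loss
        # Based on Effective Number of Samples"): a class with n samples
        # has effective number (1 - beta^n) / (1 - beta); each class is
        # weighted by its inverse, normalized to sum to the class count.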
        effective_num = 1.0 - np.power(beta, samples_per_cls)
        weights = (1.0 - beta) / np.array(effective_num)
        weights = weights / np.sum(weights) * len(weights)
        self.weights = weights
        self.num_classes = len(weights)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Unused keyword arguments, accepted for interface
                compatibility.

        Returns:
            torch.Tensor: The returned class-balanced focal loss.
        """
        weights = torch.tensor(self.weights).float().to(cls_score.device)
        label_one_hot = F.one_hot(label, self.num_classes).float()
        weights = weights.unsqueeze(0)
        weights = weights.repeat(label_one_hot.shape[0], 1) * label_one_hot
        weights = weights.sum(1)
        weights = weights.unsqueeze(1)
        weights = weights.repeat(1, self.num_classes)

        BCELoss = F.binary_cross_entropy_with_logits(
            input=cls_score, target=label_one_hot, reduction='none')
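
        # Focal modulator (1 - p_t) ** gamma evaluated in log space for
        # numerical stability: with one-hot target y and logit x,
        # exp(-gamma * y * x - gamma * log(1 + exp(-x))) equals
        # (1 - sigmoid(x)) ** gamma where y == 1 and sigmoid(x) ** gamma
        # where y == 0.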
        modulator = 1.0
        if self.gamma:
            modulator = torch.exp(-self.gamma * label_one_hot * cls_score -
                                  self.gamma *
                                  torch.log(1 + torch.exp(-1.0 * cls_score)))

        loss = modulator * BCELoss
        weighted_loss = weights * loss
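
        # Average over the number of positive targets, which equals the
        # batch size for one-hot labels, rather than batch * num_classes.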
        focal_loss = torch.sum(weighted_loss)
        focal_loss /= torch.sum(label_one_hot)

        return focal_loss