File size: 771 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


@MODELS.register_module()
class NLLLoss(BaseWeightedLoss):
    """Negative log-likelihood (NLL) loss.

    Thin wrapper that delegates to :func:`torch.nn.functional.nll_loss`
    using the given ``cls_score`` and ``label``.

    Note:
        ``F.nll_loss`` treats its input as log-probabilities, so
        ``cls_score`` should already be log-probabilities (for example the
        output of ``log_softmax``). No normalisation is applied here.
    """

    def _forward(self, cls_score, label, **kwargs):
        """Compute the NLL loss for a batch of predictions.

        NOTE(review): the leading underscore suggests this is a hook called
        by ``BaseWeightedLoss`` (which may scale the result); confirm
        against the base class.

        Args:
            cls_score (torch.Tensor): Predicted class scores. ``F.nll_loss``
                interprets these as log-probabilities.
            label (torch.Tensor): Ground-truth class indices.
            kwargs: Extra keyword arguments passed through unchanged to
                ``F.nll_loss`` (e.g. ``weight``, ``ignore_index``,
                ``reduction``).

        Returns:
            torch.Tensor: The loss tensor produced by ``F.nll_loss``.
        """
        return F.nll_loss(cls_score, label, **kwargs)