| import torch.nn as nn |
| from .abstract_loss_func import AbstractLossClass |
| from metrics.registry import LOSSFUNC |
|
|
|
|
@LOSSFUNC.register_module(module_name="cross_entropy")
class CrossEntropyLoss(AbstractLossClass):
    """Cross-entropy loss, registered in ``LOSSFUNC`` as ``"cross_entropy"``.

    Thin wrapper that delegates the computation to ``torch.nn.CrossEntropyLoss``
    constructed with its default arguments.
    """

    def __init__(self):
        super().__init__()
        # The underlying torch criterion; built once and reused on every call.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Evaluate the cross-entropy loss for a batch of predictions.

        Args:
            inputs: Tensor of predicted class scores, expected to have shape
                (batch_size, num_classes).
            targets: Tensor of ground-truth class indices, expected to have
                shape (batch_size).

        Returns:
            The loss tensor produced by ``self.loss_fn`` (a scalar with the
            criterion's default reduction).
        """
        return self.loss_fn(inputs, targets)