| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from uniperceiver.config import configurable |
| | from .build import LOSSES_REGISTRY |
| |
|
@LOSSES_REGISTRY.register()
class Accuracy(nn.Module):
    """Metric module reporting top-1 accuracy for each logits/targets pair.

    The module holds no parameters or configuration. Calling it with an
    outputs dict yields a dict that maps a metric name to a scalar accuracy
    tensor, one entry per pair (entries sharing a name overwrite each other).
    """

    @configurable
    def __init__(self):
        super().__init__()

    @classmethod
    def from_config(cls, cfg):
        """Return the constructor keyword arguments; none are read from ``cfg``."""
        return {}

    def Forward(self, logits, targets):
        """Compute the fraction of predictions that equal their target.

        Args:
            logits: Tensor of scores; it is flattened to two dimensions,
                keeping the last dimension, and the arg-max over that last
                dimension is taken as the prediction.
            targets: Tensor of target values; it is flattened to one
                dimension before the comparison (presumably class indices,
                one per prediction -- confirm against the caller).

        Returns:
            A scalar float tensor: the mean of the element-wise equality
            between predictions and flattened targets.
        """
        width = logits.shape[-1]
        flat_scores = logits.view(-1, width)
        predicted = flat_scores.argmax(dim=-1)
        expected = targets.view(-1)
        hits = predicted.eq(expected)
        return hits.float().mean()

    def forward(self, outputs_dict):
        """Compute one accuracy value per entry of ``outputs_dict``.

        Args:
            outputs_dict: Mapping that provides the parallel sequences
                ``'logits'``, ``'targets'`` and ``'loss_names'``; they are
                walked in lockstep, stopping at the shortest one.

        Returns:
            Dict whose keys are ``'Accuracy'`` when the matching name is
            empty and ``'Accuracy (<name>)'`` otherwise, and whose values
            are the scalar accuracies from ``Forward``.
        """
        metrics = {}
        triples = zip(
            outputs_dict['logits'],
            outputs_dict['targets'],
            outputs_dict['loss_names'],
        )
        for scores, labels, tag in triples:
            if scores.shape == labels.shape:
                # Targets shaped like the logits are reduced to indices first
                # (presumably soft / one-hot labels, e.g. from mixup -- TODO confirm).
                labels = labels.argmax(dim=-1)
            value = self.Forward(scores, labels)
            suffix = f' ({tag})' if len(tag) > 0 else ''
            metrics['Accuracy' + suffix] = value
        return metrics