| | import torch |
| | import torch.nn as nn |
| | from uniperceiver.config import configurable |
| | from .build import LOSSES_REGISTRY |
| |
|
@LOSSES_REGISTRY.register()
class BCEWithLogits(nn.Module):
    """Binary cross-entropy (on raw logits) loss with an optional global weight.

    Wraps ``nn.BCEWithLogitsLoss(reduction="mean")`` and applies it
    independently to every (logit, target) pair found in the outputs dict.
    """

    @configurable
    def __init__(self, loss_weight=1.0):
        """Build the criterion and store the loss weight.

        Args:
            loss_weight: Scalar multiplier applied to every computed loss.
                Any real number (``int`` or ``float``) is accepted and stored
                as ``float``. Anything else (e.g. ``None``, ``bool``) falls
                back to ``1.0``, i.e. no re-weighting.
        """
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
        # Integer weights (e.g. ``LOSS_WEIGHT: 2`` in a config) are legitimate
        # and must not be silently replaced by the default. ``bool`` is an
        # ``int`` subclass, so it is excluded explicitly.
        is_number = (
            isinstance(loss_weight, (int, float))
            and not isinstance(loss_weight, bool)
        )
        self.loss_weight = float(loss_weight) if is_number else 1.0

    @classmethod
    def from_config(cls, cfg):
        """Extract constructor kwargs from ``cfg``.

        Reads ``cfg.LOSSES.LOSS_WEIGHT``; uses ``1.0`` when that attribute is
        missing.
        """
        return {
            'loss_weight': getattr(cfg.LOSSES, 'LOSS_WEIGHT', 1.0),
        }

    def forward(self, outputs_dict):
        """Compute one BCE-with-logits loss per (logit, target) pair.

        Args:
            outputs_dict: Mapping with the parallel sequences ``'logits'``,
                ``'targets'`` and ``'loss_names'``. They are consumed in
                lockstep with ``zip``, so iteration stops at the shortest.
                (Presumably each logit/target pair must have matching shapes
                and a floating-point target, as ``nn.BCEWithLogitsLoss``
                expects -- confirm against the caller.)

        Returns:
            dict: Maps ``'BCEWithLogits_Loss'`` -- suffixed with
            ``' (<loss name>)'`` when the corresponding loss name is
            non-empty -- to the (weighted) loss. Pairs that resolve to the
            same key overwrite each other; the last one wins.
        """
        ret = {}
        for logit, target, loss_identification in zip(
            outputs_dict['logits'],
            outputs_dict['targets'],
            outputs_dict['loss_names'],
        ):
            loss = self.criterion(logit, target)
            if self.loss_weight != 1.0:
                # Out-of-place multiply: do not mutate the tensor returned by
                # the criterion.
                loss = loss * self.loss_weight

            loss_name = 'BCEWithLogits_Loss'
            # Truthiness treats both '' and None as "no identifier" instead
            # of raising on None.
            if loss_identification:
                loss_name = f'{loss_name} ({loss_identification})'
            ret[loss_name] = loss

        return ret
| | |