| | from fastai.vision import * |
| |
|
| | from modules.model import Model |
| |
|
| |
|
| | class MultiLosses(nn.Module): |
| | def __init__(self, one_hot=True): |
| | super().__init__() |
| | self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss() |
| | self.bce = torch.nn.BCELoss() |
| |
|
| | @property |
| | def last_losses(self): |
| | return self.losses |
| |
|
| | def _flatten(self, sources, lengths): |
| | return torch.cat([t[:l] for t, l in zip(sources, lengths)]) |
| |
|
| | def _merge_list(self, all_res): |
| | if not isinstance(all_res, (list, tuple)): |
| | return all_res |
| | def merge(items): |
| | if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0) |
| | else: return items[0] |
| | res = dict() |
| | for key in all_res[0].keys(): |
| | items = [r[key] for r in all_res] |
| | res[key] = merge(items) |
| | return res |
| |
|
| | def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True): |
| | loss_name = output.get('name') |
| | pt_logits, weight = output['logits'], output['loss_weight'] |
| |
|
| | assert pt_logits.shape[0] % gt_labels.shape[0] == 0 |
| | iter_size = pt_logits.shape[0] // gt_labels.shape[0] |
| | if iter_size > 1: |
| | gt_labels = gt_labels.repeat(iter_size, 1, 1) |
| | gt_lengths = gt_lengths.repeat(iter_size) |
| | flat_gt_labels = self._flatten(gt_labels, gt_lengths) |
| | flat_pt_logits = self._flatten(pt_logits, gt_lengths) |
| |
|
| | nll = output.get('nll') |
| | if nll is not None: |
| | loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight |
| | else: |
| | loss = self.ce(flat_pt_logits, flat_gt_labels) * weight |
| | if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss |
| |
|
| | return loss |
| |
|
| | def forward(self, outputs, *args): |
| | self.losses = {} |
| | if isinstance(outputs, (tuple, list)): |
| | outputs = [self._merge_list(o) for o in outputs] |
| | return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.]) |
| | else: |
| | return self._ce_loss(outputs, *args, record=False) |
| |
|
| |
|
| | class SoftCrossEntropyLoss(nn.Module): |
| | def __init__(self, reduction="mean"): |
| | super().__init__() |
| | self.reduction = reduction |
| |
|
| | def forward(self, input, target, softmax=True): |
| | if softmax: log_prob = F.log_softmax(input, dim=-1) |
| | else: log_prob = torch.log(input) |
| | loss = -(target * log_prob).sum(dim=-1) |
| | if self.reduction == "mean": return loss.mean() |
| | elif self.reduction == "sum": return loss.sum() |
| | else: return loss |
| |
|