| | |
| | |
| | |
| | |
| |
|
| | import logging |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'):
    """Compute cross entropy with stock PyTorch ops.

    The log-softmax is taken over the last dimension of ``logits`` and is
    always computed in float32, whatever the dtype of ``logits``; the
    result is then passed to ``F.nll_loss``.

    Args:
        logits: unnormalized class scores; classes are on the last dimension.
            NOTE(review): ``F.nll_loss`` expects the class dimension at
            index 1, so this presumably receives 2-D ``(N, C)`` input --
            confirm against callers.
        target: class indices, forwarded unchanged to ``F.nll_loss``.
        ignore_index: target value that is excluded from the loss. ``None``
            (the default) is mapped to -100, the default of ``F.nll_loss``.
        reduction: reduction mode forwarded to ``F.nll_loss``.

    Returns:
        The loss tensor produced by ``F.nll_loss``.
    """
    if ignore_index is None:
        # F.nll_loss only accepts an int here; forwarding None raises a
        # TypeError, which made this function's own default unusable.
        ignore_index = -100
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs, target, ignore_index=ignore_index, reduction=reduction,
    )
| |
|
| |
|
try:
    # xentropy_cuda is not referenced below; it is presumably imported only
    # to check that the compiled extension is available -- TODO confirm.
    import xentropy_cuda  # noqa: F401
    from apex.contrib import xentropy
except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        """Compute cross entropy using the pure-PyTorch implementation.

        This definition is used when the fused kernel cannot be imported.

        Args:
            logits: unnormalized class scores; classes on the last dimension.
            target: class indices.
            ignore_index: target value excluded from the loss.
            reduction: reduction mode, forwarded to the PyTorch helper.

        Returns:
            The loss tensor returned by ``_cross_entropy_pytorch``.
        """
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)

else:
    logger.info('using fused cross entropy')

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        """Compute cross entropy, using the fused apex kernel off-CPU.

        CPU tensors are delegated to ``_cross_entropy_pytorch``. Tensors on
        any other device go through
        ``xentropy.SoftmaxCrossEntropyLoss`` and are reduced here.

        Args:
            logits: unnormalized class scores.
            target: class indices.
            ignore_index: target value excluded from the loss.
            reduction: one of ``'mean'``, ``'sum'`` or ``'none'``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``; the unreduced losses
            for ``'none'``.

        Raises:
            NotImplementedError: if ``reduction`` is not one of the supported
                modes (only checked on the fused path; the CPU path defers
                to PyTorch's own validation).
        """
        if logits.device == torch.device('cpu'):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)

        # Reject unsupported modes before launching the kernel.
        if reduction not in ('mean', 'sum', 'none'):
            raise NotImplementedError(
                'unsupported reduction: {!r}'.format(reduction)
            )

        half_to_float = (logits.dtype == torch.half)
        # The 0.0 positional argument is presumably the label-smoothing
        # factor -- confirm against the apex signature.
        losses = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits, target, 0.0, ignore_index, half_to_float,
        )

        if reduction == 'none':
            return losses
        if reduction == 'sum':
            return losses.sum()

        # 'mean': average over non-ignored targets only, for every value of
        # ignore_index. Previously a negative ignore_index (including the
        # default -100) averaged over all elements, which disagreed with the
        # CPU branch above and with the fallback whenever the target
        # actually contained the ignore value. When no target equals
        # ignore_index the denominator is the element count, so the result
        # is unchanged. This relies on ignored positions contributing zero
        # loss, as the previous non-negative branch already assumed.
        return losses.sum() / target.ne(ignore_index).sum()
| |
|