import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    # Compute log-probabilities in float32 for numerical stability, even when
    # the logits arrive in half precision.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    # F.nll_loss requires an int ignore_index, so the default here is -100
    # (nll_loss's own default) rather than None, which would raise a TypeError.
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
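# Quick sanity check, as a hand-runnable sketch: log_softmax followed by
# nll_loss is the same computation as torch's built-in F.cross_entropy, so the
# helper above should agree with it on ordinary inputs:
#
#     logits = torch.randn(8, 5)
#     target = torch.randint(0, 5, (8,))
#     assert torch.allclose(
#         _cross_entropy_pytorch(logits, target),
#         F.cross_entropy(logits, target),
#         atol=1e-6,
#     )

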
try:
    # Importing the CUDA extension directly guarantees it is built and
    # loadable before we commit to apex's fused kernel.
    import xentropy_cuda
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            # Log that the fused path is active, but only on the first call.
            if not getattr(cross_entropy, "_has_logged_once", False):
                logger.info("using fused cross entropy")
                cross_entropy._has_logged_once = True

            # Positional args are (logits, labels, smoothing, padding_idx,
            # half_to_float); half_to_float asks the kernel to return float32
            # losses for half-precision logits.
            half_to_float = logits.dtype == torch.half
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
                target,
                0.0,
                ignore_index,
                half_to_float,
            )
            if reduction == "sum":
                return losses.sum()
            elif reduction == "mean":
                if ignore_index >= 0:
                    # Average over non-ignored targets only, matching
                    # F.nll_loss's "mean" semantics.
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == "none":
                return losses
            else:
                raise NotImplementedError
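
    # With reduction="mean" and a non-negative ignore_index, the division
    # above acts as a masked mean. Assuming the fused kernel emits zero loss
    # at ignored positions, a sketch using the local names from the function
    # above:
    #
    #     mask = target.ne(ignore_index)
    #     assert torch.allclose(losses.sum() / mask.sum(), losses[mask].mean())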

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
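

if __name__ == "__main__":
    # Minimal CPU smoke test, as a sketch (the tensors here are made up for
    # illustration). On CPU both variants of cross_entropy resolve to the
    # pure-PyTorch helper.
    logits = torch.randn(4, 10)
    target = torch.tensor([1, 2, -100, 3])  # -100 marks an ignored position

    # "mean" averages over the three non-ignored targets only.
    print("mean:", cross_entropy(logits, target, reduction="mean").item())
    # "sum" adds per-position losses; the ignored position contributes zero.
    print("sum:", cross_entropy(logits, target, reduction="sum").item())
    # "none" returns one loss per position (zero at the ignored one).
    print("none:", cross_entropy(logits, target, reduction="none"))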