# fairseq/modules/cross_entropy.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)
def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction='mean'):
    """Pure-PyTorch fallback: log-softmax in float32 followed by NLL loss."""
    # Default of -100 matches F.nll_loss's own default; None is not a valid
    # ignore_index and would raise a TypeError inside F.nll_loss.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(lprobs, target, ignore_index=ignore_index, reduction=reduction)
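
# A minimal self-check sketch (not part of the original module): the helper
# above should agree with torch.nn.functional.cross_entropy on CPU float32
# inputs. The function name `_check_pytorch_fallback` is hypothetical and
# only illustrative.
def _check_pytorch_fallback():
    torch.manual_seed(0)
    logits = torch.randn(8, 5)
    target = torch.randint(0, 5, (8,))
    expected = F.cross_entropy(logits, target, reduction='mean')
    actual = _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction='mean')
    assert torch.allclose(actual, expected, atol=1e-6)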
try:
    # Importing xentropy_cuda first checks that apex's fused CUDA kernel was
    # actually compiled, not merely that the apex Python package is installed.
    import xentropy_cuda
    from apex.contrib import xentropy

    logger.info('using fused cross entropy')

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        if logits.device == torch.device('cpu'):
            # The fused kernel is CUDA-only; fall back on CPU tensors.
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            half_to_float = (logits.dtype == torch.half)
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits, target, 0.0, ignore_index, half_to_float,
            )
            if reduction == 'sum':
                return losses.sum()
            elif reduction == 'mean':
                if ignore_index >= 0:
                    # Average only over targets that are not ignored.
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    # With a negative ignore_index, targets are assumed to be
                    # valid (non-negative) class indices, so nothing is
                    # ignored and a plain mean is used.
                    return losses.mean()
            elif reduction == 'none':
                return losses
            else:
                raise NotImplementedError

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
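
# Usage sketch (not part of the original module). On CPU both definitions of
# cross_entropy reduce to the pure-PyTorch fallback, so this runs with or
# without apex installed.
if __name__ == '__main__':
    torch.manual_seed(0)
    logits = torch.randn(4, 10)
    target = torch.tensor([1, -100, 3, 7])  # positions equal to -100 are ignored
    print('mean loss:', cross_entropy(logits, target, reduction='mean').item())
    print('per-token:', cross_entropy(logits, target, reduction='none'))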