File size: 1,398 Bytes
d008243 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn.functional as F
__all__ = ["label_smooth", "CrossEntropyWithSoftTarget", "CrossEntropyWithLabelSmooth"]
def label_smooth(
    target: torch.Tensor, n_classes: int, smooth_factor: float = 0.1
) -> torch.Tensor:
    """Convert integer class labels into label-smoothed one-hot rows.

    Args:
        target: 1-D tensor of class indices, shape ``(batch,)``. Any integer
            dtype is accepted; it is cast to int64 for ``scatter_``.
        n_classes: total number of classes (width of the output).
        smooth_factor: probability mass spread uniformly over all classes.

    Returns:
        Float tensor of shape ``(batch, n_classes)`` where the true class
        receives ``1 - smooth_factor + smooth_factor / n_classes`` and every
        other class receives ``smooth_factor / n_classes``; each row sums to 1.
    """
    batch_size = target.shape[0]
    # scatter_ demands an int64 index tensor; cast so int32/int16 labels work.
    index = target.long().unsqueeze(1)
    soft_target = torch.zeros((batch_size, n_classes), device=target.device)
    soft_target.scatter_(1, index, 1)
    # Mix the one-hot distribution with the uniform distribution.
    return soft_target * (1 - smooth_factor) + smooth_factor / n_classes
class CrossEntropyWithSoftTarget:
    """Cross-entropy loss against a soft (probability-distribution) target.

    Unlike ``F.cross_entropy``, the target here is a full per-class
    distribution of the same shape as ``pred``, not integer class indices.
    """

    @staticmethod
    def get_loss(pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of ``-sum(soft_target * log_softmax(pred))``.

        Args:
            pred: raw logits, shape ``(batch, n_classes)``.
            soft_target: target distribution, same shape as ``pred``.

        Returns:
            Scalar tensor with the mean cross-entropy over the batch.
        """
        # Dropped the private `_stacklevel` keyword: it only tunes warning
        # stack traces inside torch and is not part of the public API.
        return torch.mean(
            torch.sum(-soft_target * F.log_softmax(pred, dim=-1), dim=1)
        )

    def __call__(self, pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Allow instances to be used directly as loss functions."""
        return self.get_loss(pred, soft_target)
class CrossEntropyWithLabelSmooth:
    """Cross-entropy loss with label smoothing over integer class targets.

    Hard labels are first converted into smoothed distributions via
    ``label_smooth`` and then scored with the soft-target cross-entropy.
    """

    def __init__(self, smooth_ratio=0.1):
        super().__init__()
        # Fraction of probability mass spread uniformly across all classes.
        self.smooth_ratio = smooth_ratio

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the smoothed cross-entropy between logits and labels.

        Args:
            pred: raw logits, shape ``(batch, n_classes)``.
            target: integer class indices, shape ``(batch,)``.

        Returns:
            Scalar tensor with the mean loss over the batch.
        """
        num_classes = pred.shape[1]
        smoothed = label_smooth(target, num_classes, self.smooth_ratio)
        return CrossEntropyWithSoftTarget.get_loss(pred, smoothed)
|