|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
__all__ = ["label_smooth", "CrossEntropyWithSoftTarget", "CrossEntropyWithLabelSmooth"] |
|
|
|
|
|
|
|
|
def label_smooth(
    target: torch.Tensor, n_classes: int, smooth_factor: float = 0.1
) -> torch.Tensor:
    """Convert hard class indices into label-smoothed one-hot targets.

    Each row becomes a one-hot vector scaled by ``1 - smooth_factor`` with
    ``smooth_factor / n_classes`` added uniformly to every class, so each
    row still sums to 1.

    Args:
        target: integer class indices of shape ``(batch,)``.
        n_classes: total number of classes (width of the output).
        smooth_factor: probability mass redistributed uniformly over classes.

    Returns:
        Float tensor of shape ``(batch, n_classes)`` on ``target.device``.
    """
    batch_size = target.shape[0]
    # scatter_ requires a (batch, 1) index tensor to mark the hot position.
    index = torch.unsqueeze(target, 1)
    soft_target = torch.zeros((batch_size, n_classes), device=target.device)
    soft_target.scatter_(1, index, 1)
    # Shrink the one-hot peak and spread the removed mass across all classes.
    soft_target = soft_target * (1 - smooth_factor) + smooth_factor / n_classes
    return soft_target
|
|
|
|
|
|
|
|
class CrossEntropyWithSoftTarget:
    """Cross-entropy loss against a soft (probability-distribution) target.

    Stateless: use the static ``get_loss`` directly or call an instance.
    """

    @staticmethod
    def get_loss(pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of ``-sum(soft_target * log_softmax(pred))``.

        Args:
            pred: raw logits of shape ``(batch, n_classes)``.
            soft_target: target distribution with the same shape as ``pred``.

        Returns:
            Scalar loss tensor.
        """
        # The private `_stacklevel` argument previously passed to
        # F.log_softmax only tunes warning stack reporting and is internal
        # torch API; it has been dropped.
        return torch.mean(
            torch.sum(-soft_target * F.log_softmax(pred, dim=-1), 1)
        )

    def __call__(self, pred: torch.Tensor, soft_target: torch.Tensor) -> torch.Tensor:
        """Delegate to :meth:`get_loss`."""
        return self.get_loss(pred, soft_target)
|
|
|
|
|
|
|
|
class CrossEntropyWithLabelSmooth:
    """Cross-entropy loss with label smoothing applied to hard targets.

    Smooths the integer targets via ``label_smooth`` and evaluates the
    soft-target cross entropy with ``CrossEntropyWithSoftTarget``.
    """

    def __init__(self, smooth_ratio: float = 0.1):
        """
        Args:
            smooth_ratio: probability mass redistributed uniformly over
                classes during smoothing.
        """
        # This class derives from `object`; the previous
        # super().__init__() call (an nn.Module-style leftover) was a
        # no-op and has been removed.
        self.smooth_ratio = smooth_ratio

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed cross entropy.

        Args:
            pred: logits of shape ``(batch, n_classes)``.
            target: integer class indices of shape ``(batch,)``.

        Returns:
            Scalar loss tensor.
        """
        soft_target = label_smooth(target, pred.shape[1], self.smooth_ratio)
        return CrossEntropyWithSoftTarget.get_loss(pred, soft_target)
|
|
|