|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class OrdinalLogLoss(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
num_classes, |
|
|
alpha=1.0, |
|
|
reduction='mean', |
|
|
distance_matrix=None, |
|
|
class_weights=None, |
|
|
eps=1e-8 |
|
|
): |
|
|
super(OrdinalLogLoss, self).__init__() |
|
|
self.num_classes = num_classes |
|
|
self.alpha = alpha |
|
|
self.reduction = reduction |
|
|
self.eps = eps |
|
|
|
|
|
if distance_matrix is not None: |
|
|
assert distance_matrix.shape == (num_classes, num_classes), \ |
|
|
"Distance matrix must be of shape (num_classes, num_classes)" |
|
|
self.register_buffer('distance_matrix', distance_matrix.float()) |
|
|
else: |
|
|
idx = torch.arange(num_classes).float() |
|
|
default_matrix = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)) |
|
|
self.register_buffer('distance_matrix', default_matrix) |
|
|
|
|
|
if class_weights is not None: |
|
|
assert class_weights.shape == (num_classes,), \ |
|
|
"Class weights must be of shape (num_classes,)" |
|
|
self.register_buffer('class_weights', class_weights.float()) |
|
|
else: |
|
|
self.class_weights = None |
|
|
|
|
|
def forward(self, logits, target): |
|
|
probs = torch.softmax(logits, dim=1).clamp(max=1 - self.eps) |
|
|
distances = self.distance_matrix[target] ** self.alpha |
|
|
per_class_loss = -torch.log(1 - probs + self.eps) |
|
|
loss = (per_class_loss * distances).sum(dim=1) |
|
|
|
|
|
|
|
|
if self.class_weights is not None: |
|
|
sample_weights = self.class_weights[target] |
|
|
loss = loss * sample_weights |
|
|
|
|
|
|
|
|
if self.reduction == 'mean': |
|
|
return loss.mean() |
|
|
elif self.reduction == 'sum': |
|
|
return loss.sum() |
|
|
else: |
|
|
return loss |
|
|
|
|
|
|