# sayest/models.py
# Author: Aleksei Žavoronkov
# Commit 32fef6f: add duration scores functionality
import torch
import torch.nn as nn
class OrdinalLogLoss(nn.Module):
    """Ordinal log loss (OLL) for ordinal classification.

    For a sample with target class y, each class k is penalized by
    ``d(y, k)**alpha * -log(1 - p_k)``, where ``p_k`` is the softmax
    probability of class k and ``d`` is an ordinal distance.  The target
    class itself has distance 0 and therefore contributes nothing;
    probability mass placed on classes far from the target is penalized
    more heavily than mass on adjacent classes.

    Args:
        num_classes: Number of ordinal classes C.
        alpha: Exponent applied to the distance matrix; larger values
            penalize distant misclassifications more strongly.
        reduction: 'mean', 'sum', or 'none' (per-sample losses).
        distance_matrix: Optional (C, C) tensor of pairwise class
            distances.  Defaults to absolute rank difference |i - j|.
        class_weights: Optional (C,) tensor of per-class sample weights,
            indexed by the target class.
        eps: Numerical floor to keep the log argument positive.

    Raises:
        ValueError: If ``reduction`` is unknown, or a provided
            ``distance_matrix`` / ``class_weights`` has the wrong shape.
    """

    def __init__(
        self,
        num_classes,
        alpha=1.0,
        reduction='mean',
        distance_matrix=None,
        class_weights=None,
        eps=1e-8,
    ):
        super().__init__()
        # Validate early with ValueError rather than assert: asserts are
        # stripped under `python -O`, silently disabling these checks.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
            )
        self.num_classes = num_classes
        self.alpha = alpha
        self.reduction = reduction
        self.eps = eps

        if distance_matrix is not None:
            if distance_matrix.shape != (num_classes, num_classes):
                raise ValueError(
                    "Distance matrix must be of shape (num_classes, num_classes)"
                )
            # Buffer so the matrix moves with .to(device) / shows in state_dict.
            self.register_buffer('distance_matrix', distance_matrix.float())
        else:
            # Default ordinal distance: absolute rank difference |i - j|.
            idx = torch.arange(num_classes).float()
            default_matrix = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1))
            self.register_buffer('distance_matrix', default_matrix)

        if class_weights is not None:
            if class_weights.shape != (num_classes,):
                raise ValueError("Class weights must be of shape (num_classes,)")
            self.register_buffer('class_weights', class_weights.float())
        else:
            # Plain attribute (not a buffer) so `is not None` checks stay cheap
            # and state_dict is unchanged when weights are unused.
            self.class_weights = None

    def forward(self, logits, target):
        """Compute the ordinal log loss.

        Args:
            logits: (batch, num_classes) unnormalized scores.
            target: (batch,) long tensor of target class indices.

        Returns:
            Scalar tensor for 'mean'/'sum' reduction, else a (batch,)
            tensor of per-sample losses.
        """
        # Clamp keeps 1 - p strictly positive even if softmax saturates to 1.
        probs = torch.softmax(logits, dim=1).clamp(max=1 - self.eps)
        # Row of distances from each sample's target class, sharpened by alpha.
        distances = self.distance_matrix[target] ** self.alpha
        per_class_loss = -torch.log(1 - probs + self.eps)
        loss = (per_class_loss * distances).sum(dim=1)  # shape (batch_size,)

        # Re-weight each sample by the weight of its target class.
        if self.class_weights is not None:
            loss = loss * self.class_weights[target]

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss