import torch __all__ = ['loss_with_sample_weights', 'loss_with_target_histogram'] def loss_with_sample_weights(criterion, pred, y, weights): assert weights.dim() == 1 assert pred.shape[0] == y.shape[0] == weights.shape[0] reduction_backup = criterion.reduction criterion.reduction = 'none' weights = weights.float() / weights.sum() loss = criterion(pred, y) loss = loss.sum(dim=1) if loss.dim() > 1 else loss loss = (loss * weights).sum() criterion.reduction = reduction_backup return loss def loss_with_target_histogram(criterion, pred, y_hist): assert pred.dim() == 2 assert y_hist.dim() == 2 assert pred.shape[0] == y_hist.shape[0] y_mask = y_hist != 0 logits_flat = pred.repeat_interleave(y_mask.sum(dim=1), dim=0) y_flat = torch.where(y_mask)[1] weights = y_hist[y_mask] loss = loss_with_sample_weights( criterion, logits_flat, y_flat, weights) return loss