English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
raw
history blame contribute delete
960 Bytes
import torch
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['loss_with_sample_weights', 'loss_with_target_histogram']
def loss_with_sample_weights(criterion, pred, y, weights):
    """Compute a per-sample weighted loss with weights normalised to sum to 1.

    ``criterion`` is temporarily switched to ``reduction='none'`` so it
    returns unreduced per-element losses.  Its original ``reduction`` value
    is restored afterwards, even if calling the criterion raises.

    Args:
        criterion: Loss object with a mutable ``reduction`` attribute (e.g. a
            ``torch.nn`` loss module), called as ``criterion(pred, y)``.
        pred: Predictions; dimension 0 is the batch dimension.
        y: Targets; ``y.shape[0]`` must equal ``pred.shape[0]``.
        weights: 1-D tensor with one weight per batch element.  It is divided
            by its own sum, so only relative magnitudes matter.
            NOTE(review): weights summing to zero cause a division by zero and
            a NaN/inf result -- callers must avoid that.

    Returns:
        Scalar tensor ``sum_i (weights[i] / weights.sum()) * loss_i``, where
        ``loss_i`` is sample ``i``'s unreduced loss summed over all of its
        non-batch dimensions.
    """
    assert weights.dim() == 1, "weights must be a 1-D tensor"
    assert pred.shape[0] == y.shape[0] == weights.shape[0], (
        "pred, y and weights must have the same batch size")
    reduction_backup = criterion.reduction
    criterion.reduction = 'none'
    try:
        loss = criterion(pred, y)
    finally:
        # Never leave the caller's criterion stuck at reduction='none'.
        criterion.reduction = reduction_backup
    weights = weights.float() / weights.sum()
    if loss.dim() > 1:
        # Collapse every non-batch dimension (not only dim 1) so the per-sample
        # loss is always shaped (batch,) and lines up with ``weights``.
        loss = loss.flatten(start_dim=1).sum(dim=1)
    return (loss * weights).sum()
def loss_with_target_histogram(criterion, pred, y_hist):
    """Compute a loss against targets given as per-row class histograms.

    Each non-zero cell ``y_hist[i, c]`` becomes one training pair: row ``i`` of
    ``pred`` is replicated once per non-zero cell in that row and paired with
    the integer class index ``c``, weighted by the cell's value.  The pairs
    are scored with ``loss_with_sample_weights``, which normalises the weights
    over the whole batch, so a row's total contribution is proportional to its
    histogram mass.  Rows whose histogram is entirely zero contribute nothing.

    Args:
        criterion: Loss accepted by ``loss_with_sample_weights``; it receives
            replicated rows of ``pred`` and a 1-D tensor of class indices.
        pred: 2-D tensor with one row per sample (presumably class logits --
            TODO confirm against callers).
        y_hist: 2-D tensor with the same number of rows as ``pred``;
            ``y_hist[i, c]`` is the weight given to class ``c`` for sample ``i``.

    Returns:
        Scalar tensor holding the weighted loss.
    """
    assert pred.dim() == 2
    assert y_hist.dim() == 2
    assert pred.shape[0] == y_hist.shape[0]
    present = y_hist != 0
    # The row-major order of the non-zero cells matches the order in which
    # repeat_interleave emits the replicated rows, so pairs stay aligned.
    _, class_idx = present.nonzero(as_tuple=True)
    cells_per_row = present.sum(dim=1)
    repeated_rows = pred.repeat_interleave(cells_per_row, dim=0)
    cell_weights = y_hist[present]
    return loss_with_sample_weights(
        criterion, repeated_rows, class_idx, cell_weights)