Other
English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
from torch.nn import L1Loss as TorchL1Loss
from src.loss.weighted import WeightedLossMixIn
# Public names exported by `from <this module> import *`.
__all__ = [
    "WeightedL1Loss",
    "L1Loss",
]
class WeightedL1Loss(WeightedLossMixIn, TorchL1Loss):
    """Weighted L1 loss between predicted and target offsets.

    This is basically ``torch.nn.L1Loss`` except that positive weights
    must be passed at forward time to give more importance to some
    items.

    The wrapped torch loss is always constructed with
    ``reduction='none'`` (element-wise losses). Weighting and any final
    reduction are presumably handled by ``WeightedLossMixIn`` — confirm
    against that class. Passing ``reduction='none'`` explicitly is
    accepted; any other ``reduction`` value raises ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        """Build the loss, forcing ``reduction='none'``.

        :param args: positional arguments forwarded to the parent
            constructors
        :param kwargs: keyword arguments forwarded to the parent
            constructors; ``reduction``, if given, must be ``'none'``
        :raises TypeError: if ``reduction`` is given with a value other
            than ``'none'``
        """
        # `reduction` is always supplied below. Remove any caller-provided
        # value first: leaving it in kwargs would trigger Python's opaque
        # "got multiple values for keyword argument" error, even for the
        # harmless `reduction='none'`.
        reduction = kwargs.pop('reduction', 'none')
        if reduction != 'none':
            raise TypeError(
                f"{type(self).__name__} always uses reduction='none', "
                f"got reduction={reduction!r}")
        super().__init__(*args, reduction='none', **kwargs)
class L1Loss(WeightedL1Loss):
    """L1 loss between predicted and target offsets.

    The forward signature allows using this loss wherever a weighted
    loss is expected: a ``weight`` argument is accepted but ignored.
    ``weight`` is optional, so the loss can also be called simply as
    ``loss(input, target)``.
    """

    def forward(self, input, target, weight=None):
        """Compute the unweighted loss between ``input`` and ``target``.

        :param input: predictions
        :param target: targets
        :param weight: ignored; accepted only for signature
            compatibility with weighted losses. Defaults to ``None``.
        :return: the result of ``WeightedL1Loss.forward`` called with a
            ``None`` weight
        """
        # Always pass None upstream so that caller-supplied weights are
        # discarded.
        return super().forward(input, target, None)