File size: 502 Bytes
4c1e73e
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn

class HuberLoss(nn.Module):
    """Huber loss between ``inputs`` and ``targets``.

    Element-wise, with ``d = |inputs - targets|``::

        loss = 0.5 * d**2                  if d < delta
        loss = delta * (d - 0.5 * delta)   otherwise

    This is the same piecewise definition used by ``torch.nn.HuberLoss``.
    It coincides with Smooth L1 loss only when ``delta == 1``; for other
    values, Smooth L1 (with ``beta == delta``) equals this loss divided by
    ``delta``.

    Args:
        delta: Threshold at which the loss switches from quadratic to
            linear. Must be positive.
        reduction: ``'mean'`` (default) averages over all elements,
            ``'sum'`` sums them, and ``'none'`` returns the unreduced
            element-wise loss.

    Raises:
        ValueError: If ``delta`` is not positive or ``reduction`` is not one
            of ``'none'``, ``'mean'``, ``'sum'``.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, delta=1.0, reduction='mean'):
        super().__init__()
        # A non-positive delta makes the linear branch degenerate (zero or
        # negative loss), so reject it up front.
        if delta <= 0:
            raise ValueError(f"delta must be positive, got {delta!r}")
        # Fail fast instead of silently falling back to 'sum' on a typo.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.delta = delta
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the Huber loss of ``inputs`` against ``targets``.

        ``inputs`` and ``targets`` are combined with ``inputs - targets``, so
        they must be broadcast-compatible tensors. The result is a scalar
        tensor for ``'mean'``/``'sum'`` and an element-wise tensor for
        ``'none'``.
        """
        diff = torch.abs(inputs - targets)
        loss = torch.where(
            diff < self.delta,
            0.5 * diff ** 2,                          # quadratic region
            self.delta * (diff - 0.5 * self.delta),   # linear region
        )
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        # Guards against `reduction` being reassigned after construction.
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )