import torch import torch.nn as nn import torch.nn.functional as F class L1Loss(nn.modules.loss._Loss): def __init__(self): super(L1Loss, self).__init__() def forward(self, sr, hr): sr_ = (sr != 0).sum().float() hr_ = (hr != 0).sum().float() l = F.l1_loss(sr_, hr_) return l