InPeerReview's picture
Upload 161 files
226675b verified
import torch
import torch.nn as nn
from rscd.losses.loss_func import CELoss, FocalLoss, dice_loss, BCEDICE_loss, LOVASZ
from rscd.losses.mask2formerLoss import Mask2formerLoss
from rscd.losses.RSMambaLoss import FCCDN_loss_without_seg
class myLoss(nn.Module):
def __init__(self, param, loss_name=['CELoss'], loss_weight=[1.0], **kwargs):
super(myLoss, self).__init__()
self.loss_weight = loss_weight
self.loss = list()
for _loss in loss_name:
self.loss.append(eval(_loss)(**param[_loss],**kwargs))
def forward(self, preds, target):
loss = 0
for i in range(0, len(self.loss)):
loss += self.loss[i](preds, target) * self.loss_weight[i]
return loss
def build_loss(cfg):
loss_type = cfg.pop('type')
obj_cls = eval(loss_type)
obj = obj_cls(**cfg)
return obj