File size: 868 Bytes
226675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
from rscd.losses.loss_func import CELoss, FocalLoss, dice_loss, BCEDICE_loss, LOVASZ
from rscd.losses.mask2formerLoss import Mask2formerLoss
from rscd.losses.RSMambaLoss import FCCDN_loss_without_seg
class myLoss(nn.Module):
    """Weighted sum of several loss criteria built from config names.

    Each entry of ``loss_name`` is evaluated as a Python expression in this
    module's namespace (e.g. ``'CELoss'``) and called with the keyword
    arguments ``param[name]`` plus any extra ``**kwargs``. In :meth:`forward`
    the resulting criteria are applied in order and combined as
    ``sum(loss_i(preds, target) * loss_weight[i])``.

    Args:
        param: Mapping from each loss name to the keyword arguments used to
            construct that loss.
        loss_name: Names of the losses to instantiate. Defaults to
            ``('CELoss',)``.
        loss_weight: Per-loss multipliers, aligned positionally with
            ``loss_name``. Defaults to ``(1.0,)``. Extra trailing weights are
            ignored.
        **kwargs: Extra keyword arguments forwarded to every loss constructor.

    Raises:
        KeyError: If ``param`` has no entry for one of ``loss_name``.
        ValueError: If fewer weights than loss names are given.
    """

    def __init__(self, param, loss_name=('CELoss',), loss_weight=(1.0,), **kwargs):
        super(myLoss, self).__init__()
        loss_name = list(loss_name)
        # Copy so the instance never aliases a caller's list (or a shared default).
        self.loss_weight = list(loss_weight)
        if len(self.loss_weight) < len(loss_name):
            # Fail at construction instead of with an IndexError inside forward().
            raise ValueError(
                f"loss_weight has {len(self.loss_weight)} entries but "
                f"loss_name has {len(loss_name)}; one weight per loss is required"
            )
        # NOTE(review): a plain list does not register nn.Module losses as
        # submodules, so .to()/.cuda()/.eval() and state_dict() do not reach
        # them. Consider nn.ModuleList if any loss holds buffers/parameters.
        self.loss = list()
        for _loss in loss_name:
            # SECURITY: eval() executes the config string as code; only use
            # with trusted configuration.
            self.loss.append(eval(_loss)(**param[_loss], **kwargs))

    def forward(self, preds, target):
        """Return the weighted sum of all configured losses on (preds, target)."""
        loss = 0
        for criterion, weight in zip(self.loss, self.loss_weight):
            loss += criterion(preds, target) * weight
        return loss
def build_loss(cfg):
    """Instantiate a loss object from a config mapping.

    ``cfg['type']`` names the class to build. It is evaluated as a Python
    expression in this module's namespace, e.g. ``'myLoss'``. Every other key
    is passed to that class's constructor as a keyword argument. ``cfg``
    itself is not modified, so the same config can be built more than once.

    Args:
        cfg: Mapping with a ``'type'`` key plus constructor keyword arguments.

    Returns:
        The constructed loss object.

    Raises:
        KeyError: If ``cfg`` has no ``'type'`` key.
    """
    # Shallow-copy so popping 'type' does not mutate the caller's config;
    # popping in place made a second build from the same cfg fail.
    cfg = dict(cfg)
    if 'type' not in cfg:
        raise KeyError("loss config is missing required key 'type'")
    loss_type = cfg.pop('type')
    # SECURITY: eval() executes the config string as code; only use with
    # trusted configuration.
    obj_cls = eval(loss_type)
    return obj_cls(**cfg)
|