|
|
import numpy as np
|
|
|
import itertools
|
|
|
from typing import Any, Dict, List, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
MaskFormer criterion.
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from torch import nn
|
|
|
|
|
|
from rscd.losses.loss_util.criterion import SetCriterion
|
|
|
from rscd.losses.loss_util.matcher import HungarianMatcher
|
|
|
|
|
|
class Mask2formerLoss(nn.Module):
    """Total Mask2Former-style loss built on ``SetCriterion`` + ``HungarianMatcher``.

    On every call the predicted masks (final and auxiliary) are bilinearly
    upsampled by ``mask_upsample_factor``, passed together with ``target`` to a
    ``SetCriterion`` configured with the "labels" and "masks" losses, and the
    weighted ``loss_ce*``, ``loss_dice*`` and ``loss_mask*`` terms it returns
    are summed into a single value. The caller's ``preds`` dict is not modified.

    Args:
        class_weight: Weight for ``loss_ce`` terms; also the matcher's class cost.
        dice_weight: Weight for ``loss_dice`` terms; also the matcher's dice cost.
        mask_weight: Weight for ``loss_mask`` terms; also the matcher's mask cost.
        no_object_weight: Forwarded to ``SetCriterion`` as ``eos_coef``.
        dec_layers: Decoder depth; ``dec_layers - 1`` suffixed copies
            (``_0`` ... ``_{dec_layers - 2}``) of each weight are registered so
            the auxiliary-output loss terms are weighted too.
        num_classes: Forwarded to ``SetCriterion``.
        device: Device spec, converted with ``torch.device`` and forwarded to
            ``SetCriterion``.
        num_points: Point count forwarded to both the matcher and the criterion.
        oversample_ratio: Forwarded to ``SetCriterion``.
        importance_sample_ratio: Forwarded to ``SetCriterion``.
        mask_upsample_factor: Spatial scale factor applied to every
            ``pred_masks`` tensor before matching.
    """

    def __init__(self, class_weight=2.0,
                 dice_weight=5.0,
                 mask_weight=5.0,
                 no_object_weight=0.1,
                 dec_layers=10,
                 num_classes=1,
                 device="cuda:0",
                 num_points=12544,
                 oversample_ratio=3.0,
                 importance_sample_ratio=0.75,
                 mask_upsample_factor=4):
        super().__init__()
        self.device = device
        self.class_weight = class_weight
        self.dice_weight = dice_weight
        self.mask_weight = mask_weight
        self.no_object_weight = no_object_weight
        self.dec_layers = dec_layers
        self.num_classes = num_classes
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.mask_upsample_factor = mask_upsample_factor

    def _weight_dict(self) -> Dict[str, float]:
        """Return the per-term loss weights, including suffixed auxiliary copies."""
        base = {
            "loss_ce": self.class_weight,
            "loss_mask": self.mask_weight,
            "loss_dice": self.dice_weight,
        }
        weight_dict = dict(base)
        # One copy of every base weight per auxiliary decoder layer: loss_ce_0, ...
        for i in range(self.dec_layers - 1):
            weight_dict.update({f"{k}_{i}": v for k, v in base.items()})
        return weight_dict

    def _build_criterion(self):
        """Construct a fresh matcher + ``SetCriterion`` from the stored configuration."""
        matcher = HungarianMatcher(
            cost_class=self.class_weight,
            cost_mask=self.mask_weight,
            cost_dice=self.dice_weight,
            num_points=self.num_points,
        )
        return SetCriterion(
            num_classes=self.num_classes,
            matcher=matcher,
            weight_dict=self._weight_dict(),
            eos_coef=self.no_object_weight,
            losses=["labels", "masks"],
            num_points=self.num_points,
            oversample_ratio=self.oversample_ratio,
            importance_sample_ratio=self.importance_sample_ratio,
            device=torch.device(self.device),
        )

    @staticmethod
    def _upsample_preds(preds: Dict[str, Any], factor) -> Dict[str, Any]:
        """Return a shallow copy of ``preds`` with every ``pred_masks`` upsampled.

        The caller's dict, its ``aux_outputs`` list and the dicts inside that
        list are left untouched, so repeated calls on the same ``preds`` never
        upsample twice. A missing ``aux_outputs`` key is tolerated and stays
        missing in the result.
        """
        def upsample(masks):
            # A 2-tuple scale factor assumes 4-D (N, C, H, W) input — as in the
            # original code; TODO confirm against the model's output.
            return F.interpolate(
                masks,
                scale_factor=(factor, factor),
                mode="bilinear",
                align_corners=False,
            )

        outputs = dict(preds)
        outputs["pred_masks"] = upsample(preds["pred_masks"])
        if "aux_outputs" in preds:
            outputs["aux_outputs"] = [
                {**aux, "pred_masks": upsample(aux["pred_masks"])}
                for aux in preds["aux_outputs"]
            ]
        return outputs

    @staticmethod
    def _sum_weighted_losses(losses: Dict[str, Any], weight_dict: Dict[str, float]):
        """Weight each known loss term and return ``ce + dice + mask``.

        Keys absent from ``weight_dict`` are ignored. Each weighted term is
        assigned to the first group whose tag ('_ce', '_dice', '_mask') occurs
        in its key; a weighted key matching no tag does not contribute.
        ``losses`` itself is not modified (the product is computed out of place).
        Returns the float ``0.0`` if no term contributes.
        """
        loss_ce = 0.0
        loss_dice = 0.0
        loss_mask = 0.0
        for key, value in losses.items():
            if key not in weight_dict:
                continue
            weighted = value * weight_dict[key]
            if "_ce" in key:
                loss_ce = loss_ce + weighted
            elif "_dice" in key:
                loss_dice = loss_dice + weighted
            elif "_mask" in key:
                loss_mask = loss_mask + weighted
        return loss_ce + loss_dice + loss_mask

    def forward(self, preds, target):
        """Compute the total weighted loss.

        Args:
            preds: Dict with a ``"pred_masks"`` tensor and optionally an
                ``"aux_outputs"`` list of dicts, each with its own
                ``"pred_masks"``. Other keys are passed through unchanged to
                ``SetCriterion``. Not modified.
            target: Forwarded unchanged to ``SetCriterion`` (its expected
                format is defined there).

        Returns:
            Sum of the weighted classification, dice and mask terms.
        """
        criterion = self._build_criterion()
        outputs = self._upsample_preds(preds, self.mask_upsample_factor)
        loss_terms = criterion(outputs, target)
        return self._sum_weighted_losses(loss_terms, criterion.weight_dict)